diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | .hlint.yaml | 17 | ||||
| -rw-r--r-- | README.md | 5 | ||||
| -rw-r--r-- | bench/Main.hs | 71 | ||||
| -rw-r--r-- | chad-fast.cabal | 86 | ||||
| -rw-r--r-- | example/Main.hs | 2 | ||||
| -rw-r--r-- | rules/.gitignore | 1 | ||||
| -rw-r--r-- | src/AST/Count.hs | 169 | ||||
| -rw-r--r-- | src/AST/Env.hs | 59 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 145 | ||||
| -rw-r--r-- | src/CHAD.hs | 1159 | ||||
| -rw-r--r-- | src/CHAD/APIv1.hs | 178 | ||||
| -rw-r--r-- | src/CHAD/AST.hs (renamed from src/AST.hs) | 288 | ||||
| -rw-r--r-- | src/CHAD/AST/Accum.hs (renamed from src/AST/Accum.hs) | 81 | ||||
| -rw-r--r-- | src/CHAD/AST/Bindings.hs (renamed from src/AST/Bindings.hs) | 21 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs | 927 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs | 95 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs (renamed from src/AST/Pretty.hs) | 97 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse.hs | 296 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 107 | ||||
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs (renamed from src/AST/SplitLets.hs) | 56 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs (renamed from src/AST/Types.hs) | 36 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 252 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken.hs (renamed from src/AST/Weaken.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken/Auto.hs (renamed from src/AST/Weaken/Auto.hs) | 57 | ||||
| -rw-r--r-- | src/CHAD/Accum.hs | 27 | ||||
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs (renamed from src/Analysis/Identity.hs) | 72 | ||||
| -rw-r--r-- | src/CHAD/Array.hs (renamed from src/Array.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/Compile.hs (renamed from src/Compile.hs) | 620 | ||||
| -rw-r--r-- | src/CHAD/Compile/Exec.hs (renamed from src/Compile/Exec.hs) | 16 | ||||
| -rw-r--r-- | src/CHAD/Data.hs (renamed from src/Data.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/Data/VarMap.hs (renamed from src/Data/VarMap.hs) | 17 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs | 1581 | ||||
| -rw-r--r-- | src/CHAD/Drev/Accum.hs | 72 | ||||
| -rw-r--r-- | src/CHAD/Drev/EnvDescr.hs (renamed from src/CHAD/EnvDescr.hs) | 34 | ||||
| -rw-r--r-- | src/CHAD/Drev/Top.hs (renamed from src/CHAD/Top.hs) | 81 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs (renamed from src/CHAD/Types.hs) | 51 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs (renamed from src/CHAD/Types/ToTan.hs) | 32 | ||||
| -rw-r--r-- | src/CHAD/Example.hs (renamed from src/Example.hs) | 58 | ||||
| -rw-r--r-- | src/CHAD/Example/GMM.hs (renamed from src/Example/GMM.hs) | 7 | ||||
| -rw-r--r-- | src/CHAD/Example/Types.hs (renamed from src/Example/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD.hs (renamed from src/ForwardAD.hs) | 29 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers.hs (renamed from src/ForwardAD/DualNumbers.hs) | 24 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers/Types.hs (renamed from src/ForwardAD/DualNumbers/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/Interpreter.hs (renamed from src/Interpreter.hs) | 219 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Accum.hs (renamed from src/Interpreter/Accum.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/AccumOld.hs (renamed from src/Interpreter/AccumOld.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Rep.hs (renamed from src/Interpreter/Rep.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/Language.hs | 423 | ||||
| -rw-r--r-- | src/CHAD/Language/AST.hs (renamed from src/Language/AST.hs) | 130 | ||||
| -rw-r--r-- | src/CHAD/Lemmas.hs (renamed from src/Lemmas.hs) | 2 | ||||
| -rw-r--r-- | src/CHAD/Simplify.hs (renamed from src/Simplify.hs) | 374 | ||||
| -rw-r--r-- | src/CHAD/Simplify/TH.hs (renamed from src/Simplify/TH.hs) | 4 | ||||
| -rw-r--r-- | src/CHAD/Util/IdGen.hs (renamed from src/Util/IdGen.hs) | 2 | ||||
| -rw-r--r-- | src/Language.hs | 229 | ||||
| -rw-r--r-- | test-framework/Test/Framework.hs | 468 | ||||
| -rw-r--r-- | test/Main.hs | 273 |
57 files changed, 6281 insertions, 2845 deletions
@@ -1,3 +1,5 @@ dist-newstyle/ cabal.project.local .ccls-cache/ + +compile_diagnostics_watcher.sh diff --git a/.hlint.yaml b/.hlint.yaml new file mode 100644 index 0000000..7ec649a --- /dev/null +++ b/.hlint.yaml @@ -0,0 +1,17 @@ +- ignore: {name: "Avoid lambda"} +- ignore: {name: "Avoid lambda using `infix`"} +- ignore: {name: "Collapse lambdas"} +- ignore: {name: "Eta reduce"} +- ignore: {name: "Evaluate"} +- ignore: {name: "Redundant $"} +- ignore: {name: "Redundant lambda"} +- ignore: {name: "Use bimap"} +- ignore: {name: "Use camelCase"} +- ignore: {name: "Use const"} +- ignore: {name: "Use forM_"} +- ignore: {name: "Use newtype instead of data"} +- ignore: {name: "Use record patterns"} +- ignore: {name: "Use tuple-section"} +- ignore: {name: "Use unless"} +- ignore: {name: "Use unwords"} +- ignore: {name: "Use void"} diff --git a/README.md b/README.md new file mode 100644 index 0000000..585e3b0 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# chad-fast + +This is work-in-progress research software. If you are interested in techniques +or ideas in this repository, please do [contact me](https://tomsmeding.com), +I'm always happy to chat. diff --git a/bench/Main.hs b/bench/Main.hs index 358ba31..1e8f6f3 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,11 +1,14 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS -Wno-orphans #-} module Main where @@ -16,26 +19,31 @@ import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) -import AST -import AST.UnMonoid -import Array -import qualified CHAD (defaultConfig) -import CHAD.Top -import CHAD.Types -import Compile -import Data -import Example -import Example.GMM -import Example.Types -import Interpreter.Rep -import Simplify +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Array +import CHAD.Compile +import CHAD.Data +import CHAD.Drev qualified as CHAD (defaultConfig) +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example +import CHAD.Example.GMM +import CHAD.Example.Types +import CHAD.Interpreter.Rep +import CHAD.Simplify gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env)))) gradCHAD config term = - compile knownEnv $ - simplifyFix $ unMonoid $ simplifyFix $ - ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term + compileStderr knownEnv $ + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ + ELet ext (EConst ext STF64 1.0) $ + chad' config knownEnv $ + simplifyFix term type AllNFDataRep :: [Ty] -> Constraint type family AllNFDataRep env where @@ -93,18 +101,19 @@ makeGMMInputs = accumConfig :: CHADConfig accumConfig = chcSetAccum CHAD.defaultConfig -main :: IO () -main = defaultMain - [env (return makeNeuralInputs) $ \inputs -> bgroup "neural" - [env (gradCHAD CHAD.defaultConfig neural) $ \fun -> - bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig neural) $ \fun -> - bench "accum" (nfAppIO fun inputs) - ] - ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm" - [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun -> +bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env)))) + => String -> Ex env R -> SList Value env -> Benchmark +bgroupDefaultAccum name term inputs = + bgroup name + [env (gradCHAD CHAD.defaultConfig term) $ \fun -> bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun -> + ,env (gradCHAD accumConfig term) $ \fun -> bench "accum" (nfAppIO fun inputs) ] + +main :: IO () +main = defaultMain + [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural + ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False) + ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil) ] diff --git a/chad-fast.cabal b/chad-fast.cabal index b0ed639..1eef3ed 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -10,44 +10,49 @@ build-type: Simple library exposed-modules: - Analysis.Identity - Array - AST - AST.Accum - AST.Bindings - AST.Count - AST.Env - AST.Pretty - AST.SplitLets - AST.Types - AST.UnMonoid - AST.Weaken - AST.Weaken.Auto - CHAD - CHAD.Accum - CHAD.EnvDescr - CHAD.Top - CHAD.Types - CHAD.Types.ToTan - Compile - Compile.Exec - Data - Data.VarMap - Example - Example.GMM - Example.Types - ForwardAD - ForwardAD.DualNumbers - ForwardAD.DualNumbers.Types - Interpreter - -- Interpreter.AccumOld - Interpreter.Rep - Language - Language.AST - Lemmas - Simplify - Simplify.TH - Util.IdGen + -- default ghci module on top + CHAD.Example + + CHAD.Analysis.Identity + CHAD.APIv1 + CHAD.Array + CHAD.AST + CHAD.AST.Accum + CHAD.AST.Bindings + CHAD.AST.Count + CHAD.AST.Env + CHAD.AST.Pretty + CHAD.AST.Sparse + CHAD.AST.Sparse.Types + CHAD.AST.SplitLets + CHAD.AST.Types + CHAD.AST.UnMonoid + CHAD.AST.Weaken + CHAD.AST.Weaken.Auto + CHAD.Compile + CHAD.Compile.Exec + CHAD.Data + CHAD.Data.VarMap + CHAD.Drev + CHAD.Drev.Accum + CHAD.Drev.EnvDescr + CHAD.Drev.Top + CHAD.Drev.Types + CHAD.Drev.Types.ToTan + CHAD.Example.GMM + CHAD.Example.Types + CHAD.ForwardAD + CHAD.ForwardAD.DualNumbers + CHAD.ForwardAD.DualNumbers.Types + CHAD.Interpreter + -- CHAD.Interpreter.AccumOld + CHAD.Interpreter.Rep + CHAD.Language + CHAD.Language.AST + CHAD.Lemmas + CHAD.Simplify + CHAD.Simplify.TH + CHAD.Util.IdGen other-modules: build-depends: base >= 4.19 && < 4.21, @@ -81,7 +86,11 @@ library test-framework exposed-modules: Test.Framework build-depends: base, + ansi-terminal, + concurrent-output, hedgehog, + pqueue, + stm, time, transformers hs-source-dirs: test-framework @@ -96,7 +105,6 @@ test-suite test test-framework, base, containers, - dependent-map, hedgehog, text, transformers, diff --git a/example/Main.hs b/example/Main.hs index 6c36857..28cb7e8 100644 --- a/example/Main.hs +++ b/example/Main.hs @@ -1,6 +1,6 @@ module Main where -import Example +import CHAD.Example main :: IO () diff --git a/rules/.gitignore b/rules/.gitignore index ef5f7e5..2a45cd8 100644 --- a/rules/.gitignore +++ b/rules/.gitignore @@ -2,3 +2,4 @@ *.log *.out *.pdf +rules-?.jpg diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index 0c682c6..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,169 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Count where - -import Data.Functor.Const -import GHC.Generics (Generic, Generically(..)) - -import AST -import AST.Env -import Data - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) - --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = - getConst . occCountGeneral - (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) - (\(Const o) -> Const o) - (\(Const o1) (Const o2) -> Const (o1 <||> o2)) - (\(Const o) -> Const (scaleMany o)) - - -data OccEnv env where - OccEnd :: OccEnv env -- not necessarily top! - OccPush :: OccEnv env -> Occ -> OccEnv (t : env) - -instance Semigroup (OccEnv env) where - OccEnd <> e = e - e <> OccEnd = e - OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') - -instance Monoid (OccEnv env) where - mempty = OccEnd - -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty - -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') - -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) - -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd - -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv - -occCountGeneral :: forall r env t x. - (forall env'. Monoid (r env')) - => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot - -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env'. r env' -> r env' -> r env') -- ^ alternation - -> (forall env'. r env' -> r env') -- ^ scale-many - -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId - where - go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' - go w = \case - EVar _ _ i -> onehot w i (Occ One One) - ELet _ rhs body -> re rhs <> re1 body - EPair _ a b -> re a <> re b - EFst _ e -> re e - ESnd _ e -> re e - ENil _ -> mempty - EInl _ _ e -> re e - EInr _ _ e -> re e - ECase _ e a b -> re e <> (re1 a `alter` re1 b) - ENothing _ _ -> mempty - EJust _ e -> re e - EMaybe _ a b e -> re a <> re1 b <> re e - ELNil _ _ _ -> mempty - ELInl _ _ e -> re e - ELInr _ _ e -> re e - ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c) - EConstArr{} -> mempty - EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c - ESum1Inner _ e -> re e - EUnit _ e -> re e - EReplicate1Inner _ a b -> re a <> re b - EMaximum1Inner _ e -> re e - EMinimum1Inner _ e -> re e - EConst{} -> mempty - EIdx0 _ e -> re e - EIdx1 _ a b -> re a <> re b - EIdx _ a b -> re a <> re b - EShape _ e -> re e - EOp _ _ e -> re e - ECustom _ _ _ _ _ _ _ a b -> re a <> re b - ERecompute _ e -> re e - EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e - EZero _ _ e -> re e - EPlus _ _ a b -> re a <> re b - EOneHot _ _ _ a b -> re a <> re b - EError{} -> mempty - where - re :: Monoid (r env') => Expr x env' t'' -> r env' - re = go w - - re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' - re1 = unpush . go (WSink .> w) - - -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = - deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = - deleteUnused env occenv $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYes sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYes _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs deleted file mode 100644 index 4f34166..0000000 --- a/src/AST/Env.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Env where - -import AST.Weaken -import Data - - --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') - SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') - -subList :: SList f env -> Subenv env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) - -subenvNone :: SList f env -> Subenv env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} - -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 -subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) -subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) - -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') -subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) -subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) - -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 -sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub -sinkWithSubenv (SENo sub) = sinkWithSubenv sub - -wUndoSubenv :: Subenv env env' -> env' :> env -wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) -wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs deleted file mode 100644 index ac4d733..0000000 --- a/src/AST/UnMonoid.hs +++ /dev/null @@ -1,145 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where - -import AST -import Data - - --- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them --- into their concrete implementations. -unMonoid :: Ex env t -> Ex env t -unMonoid = \case - EZero _ t e -> zero t e - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) - - EVar _ t i -> EVar ext t i - ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) - EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) - EFst _ e -> EFst ext (unMonoid e) - ESnd _ e -> ESnd ext (unMonoid e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) - ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t - EJust _ e -> EJust ext (unMonoid e) - EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - ELNil _ t1 t2 -> ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t (unMonoid e) - ELInr _ t e -> ELInr ext t (unMonoid e) - ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) - EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) - ESum1Inner _ e -> ESum1Inner ext (unMonoid e) - EUnit _ e -> EUnit ext (unMonoid e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) - EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) - EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EConst _ t x -> EConst ext t x - EIdx0 _ e -> EIdx0 ext (unMonoid e) - EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) - EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - ERecompute _ e -> ERecompute ext (unMonoid e) - EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) - EError _ t s -> EError ext t s - -zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t -zero SMTNil _ = ENil ext -zero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) -zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) -zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e -zero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -plus SMTNil _ _ = ENil ext -plus (SMTPair t1 t2) a b = - let t = STPair (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (SMTLEither t1 t2) a b = - let t = STLEither (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - ELCase ext (EVar ext t (IS IZ)) - (EVar ext t IZ) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) - (EError ext t "plus l+r")) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (EError ext t "plus r+l") - (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) -plus (SMTMaybe t) a b = - ELet ext b $ - EMaybe ext - (EVar ext (STMaybe (fromSMTy t)) IZ) - (EJust ext - (EMaybe ext - (EVar ext (fromSMTy t) IZ) - (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) - (weakenExpr WSink a) -plus (SMTArr _ t) a b = - ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - a b -plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) - -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t -onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> - ELet ext arg $ - EVar ext (fromSMTy typ) IZ - - (SMTPair t1 t2, SAPFst prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t1 in - EPair ext (EVar ext toh IZ) - (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) - - (SMTPair t1 t2, SAPSnd prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t2 in - EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) - (EVar ext toh IZ) - - (SMTLEither t1 t2, SAPLeft prj) -> - ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) - (SMTLEither t1 t2, SAPRight prj) -> - ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) - - (SMTMaybe t1, SAPJust prj) -> - EJust ext (onehot t1 prj idx arg) - - (SMTArr n t1, SAPArrIdx prj) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ - zero t1 (EVar ext (tZeroInfo t1) IZ)) diff --git a/src/CHAD.hs b/src/CHAD.hs deleted file mode 100644 index df792ce..0000000 --- a/src/CHAD.hs +++ /dev/null @@ -1,1159 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module CHAD ( - drev, - freezeRet, - CHADConfig(..), - defaultConfig, - Storage(..), - Descr(..), - Select, -) where - -import Data.Functor.Const -import Data.Some -import Data.Type.Bool (If) -import Data.Type.Equality (type (==), testEquality) -import GHC.Stack (HasCallStack) - -import Analysis.Identity (ValId(..), validSplitEither) -import AST -import AST.Bindings -import AST.Count -import AST.Env -import AST.Weaken.Auto -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap -import Data.VarMap (VarMap) -import Lemmas - - ------------------------------- TAPES AND BINDINGS ------------------------------ - -type family Tape binds where - Tape '[] = TNil - Tape (t : ts) = TPair t (Tape ts) - -tapeTy :: SList STy binds -> STy (Tape binds) -tapeTy SNil = STNil -tapeTy (SCons t ts) = STPair t (tapeTy ts) - -bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds - -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollectTape BTop SETop _ = ENil ext -bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = - EPair ext (EVar ext t (w @> IZ)) - (bindingsCollectTape binds sub (w .> WSink)) -bindingsCollectTape (BPush binds _) (SENo sub) w = - bindingsCollectTape binds sub (w .> WSink) - --- In order from large to small: i.e. in reverse order from what we want, --- because in a Bindings, the head of the list is the bottom-most entry. -type family TapeUnfoldings binds where - TapeUnfoldings '[] = '[] - TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts - -type family Reverse l where - Reverse '[] = '[] - Reverse (t : ts) = Append (Reverse ts) '[t] - --- An expression that is always 'snd' -data UnfExpr env t where - UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t - -fromUnfExpr :: UnfExpr env t -> Ex env t -fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) - --- - A bunch of 'snd' expressions taking us from knowing that there's a --- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix --- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in --- the environment. --- - In the extended environment, another bunch of let bindings (these are --- 'fst' expressions, but no need to know that statically) that project the --- fsts out of what we introduced above, one for each type in 'ts'. -data Reconstructor env ts = - Reconstructor - (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) - (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) - -ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) -ssnoc SNil a = SCons a SNil -ssnoc (SCons t ts) a = SCons t (ssnoc ts a) - -sreverse :: SList f ts -> SList f (Reverse ts) -sreverse SNil = SNil -sreverse (SCons t ts) = ssnoc (sreverse ts) t - -stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) -stapeUnfoldings SNil = SNil -stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) - --- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. -shiftUnfolder - :: STy t - -> SList STy ts - -> Bindings UnfExpr (Tape ts : env) list - -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) -shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) -shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = - -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order - -- to expand an 'Append' in the types so that things simplify just enough. - -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list - -- of bindings produced by 'b'. We want to conclude from this that - -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know - -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after - -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. - BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t - BPush{} -> UnfExSnd itemTy t) - -growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) -growRecon t ts (Reconstructor unfbs bs) - | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) - , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) - , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env - = Reconstructor - (shiftUnfolder t ts unfbs) - -- Add a 'fst' at the bottom of the builder stack. - -- First we have to weaken most of 'bs' to skip one more binding in the - -- unfolder stack above it. - (BPush (fst (weakenBindings weakenExpr - (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) - (WSink :: env :> (Tape (t : ts) : env))) bs)) - (t - ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ - wSinks @(Tape (t : ts) : env) - (sappend ts - (sappend (sappend (sreverse (stapeUnfoldings ts)) - (SCons (tapeTy ts) SNil)) - SNil)) - @> IZ)) - -buildReconstructor :: SList STy ts -> Reconstructor env ts -buildReconstructor SNil = Reconstructor BTop BTop -buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) - --- STRATEGY FOR reconstructBindings --- --- binds = [] --- e : () --- --- binds = [c] --- e : (c, ()) --- x0 = snd x1 : () --- y1 = fst e : c --- --- binds = [b, c] --- e : (b, (c, ())) --- x1 = snd e : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- --- binds = [a, b, c] --- e : (a, (b, (c, ()))) --- x2 = snd e : (b, (c, ())) --- x1 = snd x2 : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- y3 = fst x3 : a - --- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all --- the things in the list 'binds', we want to create a let stack that extracts --- all values from that tuple and in effect "restores" the environment --- described by 'binds'. The idea is that elsewhere, we took a slice of the --- environment and saved it all in a tuple to be restored later. We --- incidentally also add a bunch of additional bindings, namely 'Reverse --- (TapeUnfoldings binds)', so the calling code just has to skip those in --- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) - -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) - ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = - let Reconstructor unf build = buildReconstructor binds - in (fst $ weakenBindings weakenExpr (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) - - ----------------------------------- DERIVATIVES --------------------------------- - -d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) -d1op (OAdd t) e = EOp ext (OAdd t) e -d1op (OMul t) e = EOp ext (OMul t) e -d1op (ONeg t) e = EOp ext (ONeg t) e -d1op (OLt t) e = EOp ext (OLt t) e -d1op (OLe t) e = EOp ext (OLe t) e -d1op (OEq t) e = EOp ext (OEq t) e -d1op ONot e = EOp ext ONot e -d1op OAnd e = EOp ext OAnd e -d1op OOr e = EOp ext OOr e -d1op OIf e = EOp ext OIf e -d1op ORound64 e = EOp ext ORound64 e -d1op OToFl64 e = EOp ext OToFl64 e -d1op (ORecip t) e = EOp ext (ORecip t) e -d1op (OExp t) e = EOp ext (OExp t) e -d1op (OLog t) e = EOp ext (OLog t) e -d1op (OIDiv t) e = EOp ext (OIDiv t) e -d1op (OMod t) e = EOp ext (OMod t) e - --- | Both primal and dual must be duplicable expressions -data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) - | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) - -d2op :: SOp a t -> D2Op a t -d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) - OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) - ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 - OToFl64 -> Linear $ \_ -> ENil ext - ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) - OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) - OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - where - d2opUnArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TScal a) t) - -> D2Op (TScal a) t - d2opUnArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENil ext - STI64 -> Linear $ \_ -> ENil ext - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENil ext - - d2opBinArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) - -> D2Op (TPair (TScal a) (TScal a)) t - d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - - floatingD2 :: ScalIsFloating a ~ True - => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r - floatingD2 STF32 k = k - floatingD2 STF64 k = k - - integralD2 :: ScalIsIntegral a ~ True - => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r - integralD2 STI32 k = k - integralD2 STI64 k = k - -desD1E :: Descr env sto -> SList STy (D1E env) -desD1E = d1e . descrList - --- d1W :: env :> env' -> D1E env :> D1E env' --- d1W WId = WId --- d1W WSink = WSink --- d1W (WCopy w) = WCopy (d1W w) --- d1W (WPop w) = WPop (d1W w) --- d1W (WThen u w) = WThen (d1W u) (d1W w) - -conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) -conv1Idx IZ = IZ -conv1Idx (IS i) = IS (conv1Idx i) - -data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - | Idx2Me (Idx (Select env sto "merge") t) - | Idx2Di (Idx (Select env sto "discr") t) - -conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, _, SAccum)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SMerge)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me (IS j) - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SDiscr)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di (IS j) -conv2Idx DTop i = case i of {} - - ------------------------------------- MONOIDS ----------------------------------- - -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) - - ------------------------------------- SUBENVS ----------------------------------- - -subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) - -> r) - -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext (d2M t) - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) - -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) - -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" - - ---------------------------------- ACCUMULATORS --------------------------------- - -fromArrayValId :: Maybe (ValId t) -> Maybe Int -fromArrayValId (Just (VIArr i _)) = Just i -fromArrayValId _ = Nothing - -accumPromote :: forall dt env sto proxy r. - proxy dt - -> Descr env sto - -> (forall stoRepl envPro. - (Select env stoRepl "merge" ~ '[]) - => Descr env stoRepl - -- ^ A revised environment description that switches - -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. - -> SList STy envPro - -- ^ New entries on top of the original dual environment, - -- that house the accumulators for the promoted arrays in - -- the original environment. - -> Subenv (Select env sto "merge") envPro - -- ^ The promoted entries were merge entries in the - -- original environment. - -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) - -- ^ All entries that were accumulators are still - -- accumulators. - -> VarMap Int (D2AcE (Select env stoRepl "accum")) - -- ^ Accumulator map for _only_ the the newly allocated - -- accumulators. - -> (forall shbinds. - SList STy shbinds - -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) - -- ^ A weakening that converts a computation in the - -- revised environment to one in the original environment - -- extended with some accumulators. - -> r) - -> r -accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of - -- Accumulators are left as-is - SAccum -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - envpro - prosub - (SEYes accrevsub) - (VarMap.sink1 accumMap) - (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) - (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) - (#pro :++: #d :++: #shb :++: #acc :++: #tl) - .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) - (#d :++: #shb :++: #acc :++: #tl) - (#acc :++: (#d :++: #shb :++: #tl))) - - SMerge -> case t of - -- Discrete values are left as-is - _ | isDiscrete t -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - (SENo prosub) - accrevsub - accumMap' - wf - - -- Values with "merge" storage are promoted to an accumulator in envPro - _ -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - (t `SCons` envpro) - (SEYes prosub) - (SENo accrevsub) - (let accumMap' = VarMap.sink1 accumMap - in case fromArrayValId vid of - Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' - Nothing -> accumMap') - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Discrete values are left as-is, nothing to do - SDiscr -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - prosub - accrevsub - accumMap - wf - where - isDiscrete :: STy t' -> Bool - isDiscrete = \case - STNil -> True - STPair a b -> isDiscrete a && isDiscrete b - STEither a b -> isDiscrete a && isDiscrete b - STLEither a b -> isDiscrete a && isDiscrete b - STMaybe a -> isDiscrete a - STArr _ a -> isDiscrete a - STScal st -> case st of - STI32 -> True - STI64 -> True - STF32 -> False - STF64 -> False - STBool -> True - STAccum{} -> False - - ----------------------------- RETURN TRIPLE FROM CHAD --------------------------- - -data Ret env0 sto t = - forall shbinds tapebinds env0Merge. - Ret (Bindings Ex (D1E env0) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (Ret env0 sto t) - -data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. - RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) - -data Rets env0 sto env list = - forall shbinds tapebinds. - Rets (Bindings Ex env shbinds) - (Subenv shbinds tapebinds) - (SList (RetPair env0 sto env shbinds tapebinds) list) -deriving instance Show (Rets env0 sto env list) - -weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t -weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 - -weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list -weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindings weakenExpr w binds - in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) - -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. - Descr env0 sto - -> SList f b1 -> SList f b2 - -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) - | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - d) - -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) - | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings b) (retConcat descr list) - , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) - , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat b binds) - (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) - sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) - subtape subtape2) - pairs)) - -freezeRet :: Descr env sto - -> Ret env sto t - -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 - in letBinds e0' $ - EPair ext - (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) - (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) - (#shbinds :++: #d :++: #d2ace :++: #tl)) - e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) - - ----------------------------- THE CHAD TRANSFORMATION --------------------------- - -drev :: forall env sto t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case - EVar _ t i -> - case conv2Idx des i of - Idx2Ac accI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) - - Idx2Me tupI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) - (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) - - Idx2Di _ -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (ENil ext) - - ELet _ (rhs :: Expr _ _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs - , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) - (weakenExpr wbody0' body1) - subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) - &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #tl) - (#d :++: (#body :++: #rhs) :++: #tl)) - body2) $ - ELet ext - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ - plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) - (EFst ext (EVar ext bodyResType (IS IZ)))) - - EPair _ a b - | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil - , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> - subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> - Ret binds - subtape - (EPair ext a1 b1) - subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) - - EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> - Ret e0 - subtape - (EFst ext e1) - sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ - weakenExpr (WCopy WSink) e2) - - ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> - Ret e0 - subtape - (ESnd ext e1) - sub - (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ - weakenExpr (WCopy WSink) e2) - - ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) - - EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EInl ext (d1 t2) e1) - sub - (ELCase ext - (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) - (zeroTup (subList (select SMerge des) sub)) - (weakenExpr (WCopy WSink) e2) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) - - EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EInr ext (d1 t1) e1) - sub - (ELCase ext - (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) - (zeroTup (subList (select SMerge des) sub)) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy WSink) e2)) - - ECase _ e (a :: Expr _ _ t) b - | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b - , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) - , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollectTape a0 subtapeA - , let collectB = bindingsCollectTape b0 subtapeB - , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 - -> - subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> - subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in - Ret (e0 `BPush` - (tPrimal, - ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) - (EFst ext (EVar ext tPrimal IZ)) - subOut - (ELet ext - (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #ta0 (subList (bindingsBinds a0) subtapeA) - &. #prea0 prerebinds - &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #ta0 :++: #tl) - (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) - a2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (ELInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tb0 (subList (bindingsBinds b0) subtapeB) - &. #preb0 prerebinds - &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #tb0 :++: #tl) - (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) - b2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (ELInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ - ELet ext - (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ - plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) - - EConst _ t val -> - Ret BTop - SETop - (EConst ext t val) - (subenvNone (select SMerge des)) - (ENil ext) - - EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - case d2op op of - Linear d2opfun -> - Ret e0 - subtape - (d1op op e1) - sub - (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy WSink) e2)) - Nonlinear d2opfun -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) - (d1op op $ EVar ext (d1 (typeOf e)) IZ) - sub - (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - - ECustom _ _ _ storety _ pr du a b - -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> - Ret (binds `BPush` (typeOf a1, a1) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) - - -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal - ERecompute _ e -> - deleteUnused (descrList des) (occCountAll e) $ \usedSub -> - let smallE = unsafeWeakenWithSubenv usedSub e in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 -> - Ret (collectBindings (desD1E des) subD1eUsed) - (subenvAll (desD1E usedDes)) - (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1) - (subenvCompose subMergeUsed sub) - (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ - weakenExpr - (autoWeak (#d (auto1 @(D2 t)) - &. #shbinds (bindingsBinds e0) - &. #tape (subList (bindingsBinds e0) subtape) - &. #d1env (desD1E usedDes) - &. #tl' (d2ace (select SAccum usedDes)) - &. #tl (d2ace (select SAccum des))) - (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) - (#shbinds :++: #d :++: #d1env :++: #tl)) - e2) - } - - EError _ t s -> - Ret BTop - SETop - (EError ext (d1 t) s) - (subenvNone (select SMerge des)) - (ENil ext) - - EConstArr _ n t val -> - Ret BTop - SETop - (EConstArr ext n t val) - (subenvNone (select SMerge des)) - (ENil ext) - - EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result - , let eltty = typeOf orige - , shty :: STy shty <- tTup (sreplicate ndim tIx) - , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> - case assertSubenvEmpty sub of { Refl -> - let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollectTape e0 subtapeE in - Ret (BTop `BPush` (shty, letBinds she0 she1) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - in EPair ext (weakenExpr w e1) (collectexpr w))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - EMaybe ext - (zeroTup envPro) - (ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim (D2 eltty))) - &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - (EVar ext (d2 (STArr ndim eltty)) IZ)) - }} - - EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EUnit ext e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) - - EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil - , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (ezeroD2 eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) - - EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STArr _ t <- typeOf e -> - Ret e0 - subtape - (EIdx0 ext e1) - sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) - - EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" - {- - EIdx1 _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) - (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) - sub - (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - -} - - EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n - , Refl <- lemZeroInfoD2 eltty - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - - EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e - , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) - (ENil ext) - - ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STArr (SS n) t <- typeOf e -> - Ret (e0 `BPush` (STArr (SS n) t, e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) - (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) - - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e - - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" - - ENothing{} -> err_unsupported "ENothing" - EJust{} -> err_unsupported "EJust" - EMaybe{} -> err_unsupported "EMaybe" - ELNil{} -> err_unsupported "ELNil" - ELInl{} -> err_unsupported "ELInl" - ELInr{} -> err_unsupported "ELInr" - ELCase{} -> err_unsupported "ELCase" - - EWith{} -> err_accum - EAccum{} -> err_accum - EZero{} -> err_monoid - EPlus{} -> err_monoid - EOneHot{} -> err_monoid - - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - err_unsupported s = error $ "CHAD: unsupported " ++ s - - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - -data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) - -data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. - RetScoped - (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) - -- ^ the merge contributions, plus the cotangent to the argument - -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) - -drevScoped :: forall a s env sto t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> STy a -> Storage s -> Maybe (ValId a) - -> Expr ValId (a : env) t - -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of - SMerge - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) - - SAccum - | Just (VIArr i _) <- argids - , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap - , Just Refl <- testEquality foundTy (STAccum (d2M argty)) - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> - RetScoped e0 subtape e1 sub $ - let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in - ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - -- Our contribution to the binding's cotangent _here_ is - -- zero, because we're contributing to an earlier binding - -- of the same value instead. - (EPair ext e2 (ezeroD2 argty)) - - | let accumMap' = case argids of - Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) - _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> - RetScoped e0 subtape e1 sub $ - EWith ext (d2M argty) (ezeroD2 argty) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - e2 - - SDiscr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - RetScoped e0 subtape e1 sub e2 diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs new file mode 100644 index 0000000..73d1580 --- /dev/null +++ b/src/CHAD/APIv1.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.APIv1 ( + -- * Expressions and types + Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..), + + -- * Reverse derivatives (Fast CHAD) + vjp, vjp', + D2, D2E, Tup, + CHADConfig(..), + + -- ** Primal type transform + -- | The primal type transform is only important when working with special + -- operations like 'CHAD.Language.custom'. + D1, + + -- * Forward derivatives (dual numbers) + jvp, jvpDN, + Tan, TanS, DN, DNS, DNE, + + -- * Working with expressions + interpret, interpret1, + compile, compile1, + fullSimplify, + SList(..), Value(..), Rep, + KnownEnv(..), KnownTy(..), + SNat(..), +) where + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Compile qualified as Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter qualified as Interpreter +import CHAD.Simplify +import CHAD.Interpreter.Rep + + +-- | Compute a reverse derivative: a vector-Jacobian product. The type has been +-- simplified with the assumption that 'D1' is the identity. +vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp = vjp' (chcSetAccum defaultConfig) + +-- | Same as 'vjp', but supply CHAD configuration. +vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp' config term + | Dict <- styKnown (d2 (typeOf term)) = + fullSimplify $ + unMonoid . simplifyFix $ -- need to merge onehots and accums for unMonoid to do its work + chad' config knownEnv (simplifyFix term) + +jvpDN :: Ex env t -> Ex (DNE env) (DN t) +jvpDN = dfwdDN + +jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t)) +jvp term + | Dict <- styKnown (tanty (knownTy @s)) + = fullSimplify $ + elet (ezipDN knownTy) $ + elet (weakenExpr (WCopy WClosed) (jvpDN term)) $ + eunzipDN (typeOf term) + where + ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s') + ezipDN STNil = ENil ext + ezipDN (STPair a b) = + EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env a)) + (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env b)) + ezipDN (STEither a b) = + ecase (EVar ext (STEither a b) (IS IZ)) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EInl ext (dn b) (ezipDN a)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr")) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl") + (EInr ext (dn a) (ezipDN b))) + ezipDN (STLEither a b) = + elcase (EVar ext (STLEither a b) (IS IZ)) + (ELNil ext (dn a) (dn b)) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN") + (ELInl ext (dn b) (ezipDN a)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr")) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN") + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl") + (ELInr ext (dn a) (ezipDN b))) + ezipDN (STMaybe t) = + emaybe (EVar ext (STMaybe t) (IS IZ)) + (ENothing ext (dn t)) + (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ)) + (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN") + (EJust ext (ezipDN t))) + ezipDN (STArr n t) = + ezipWith (ezipDN t) + (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ) + ezipDN (STScal st) = case st of + STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ) + STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ) + STI32 -> EVar ext (STScal STI32) (IS IZ) + STI64 -> EVar ext (STScal STI64) (IS IZ) + STBool -> EVar ext (STScal STBool) (IS IZ) + ezipDN STAccum{} = error "jvp: Accumulators not supported in source program" + + eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t')) + eunzipDN STNil = EPair ext (ENil ext) (ENil ext) + eunzipDN (STPair a b) = + eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 -> + eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 -> + EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2) + eunzipDN (STEither a b) = + ecase (EVar ext (STEither (dn a) (dn b)) IZ) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (EInl ext b a1) (EInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (EInr ext a b1) (EInr ext (tanty a) b2)) + eunzipDN (STLEither a b) = + elcase (EVar ext (STLEither (dn a) (dn b)) IZ) + (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b))) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2)) + eunzipDN (STMaybe t) = + emaybe (EVar ext (STMaybe (dn t)) IZ) + (EPair ext (ENothing ext t) (ENothing ext (tanty t))) + (eunPair (eunzipDN t) $ \_ e1 e2 -> + EPair ext (EJust ext e1) (EJust ext e2)) + eunzipDN (STArr n t) = + elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $ + EPair ext (emap (EFst ext (evar IZ)) (evar IZ)) + (emap (ESnd ext (evar IZ)) (evar IZ)) + eunzipDN (STScal st) = case st of + STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ + STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ + STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext) + STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext) + STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext) + eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program" + +-- | Interpret an expression in a given environment. +interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t +interpret = Interpreter.interpretOpen False knownEnv + +-- | Special case of 'interpret' for an expression with a single free variable. +interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t +interpret1 x = interpret (Value x `SCons` SNil) + +-- | Compile an expression to C, load the resulting shared object into the +-- program and wrap it in a Haskell function. +compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t)) +compile = Compile.compileStderr knownEnv + +-- | Special case of 'compile' for an expression with a single free variable. +compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t)) +compile1 term = do + f <- Compile.compileStderr knownEnv term + return (\x -> f (Value x `SCons` SNil)) + +-- | Simplify an expression. The 'vjp'/'jvp' functions already do this automatically. +fullSimplify :: KnownEnv env => Ex env t -> Ex env t +fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix diff --git a/src/AST.hs b/src/CHAD/AST.hs index 149cddd..ce9eb20 100644 --- a/src/AST.hs +++ b/src/CHAD/AST.hs @@ -1,14 +1,12 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -16,19 +14,20 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where +module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where import Data.Functor.Const import Data.Functor.Identity +import Data.Int (Int64) import Data.Kind (Type) -import Array -import AST.Accum -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST.Accum +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types -- General assumption: head of the list (whatever way it is associated) is the @@ -60,12 +59,35 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (TPair t1 t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -80,6 +102,9 @@ data Expr x env t where -- be backpropagated to; 'a' is the inactive part. The dual field of -- ECustom does not allow a derivative to be generated for 'a', and hence -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! ECustom :: x t -> STy a -> STy b -> STy tape -> Expr x [b, a] t -- ^ regular operation -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass @@ -91,13 +116,18 @@ data Expr x env t where ERecompute :: x t -> Expr x env t -> Expr x env t -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t -- interface of abstract monoidal types ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) @@ -109,6 +139,14 @@ data Expr x env t where EError :: x a -> STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) +-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full +-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@); +-- 'Ex' sets this to nothing. +-- +-- Construct expressions using the functions in "CHAD.Language". +-- +-- Use 'CHAD.AST.Pretty.pprintExpr' or 'CHAD.AST.Pretty.ppExpr' to inspect +-- expressions. type Ex = Expr (Const ()) ext :: Const () a @@ -200,12 +238,18 @@ typeOf = \case EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) + EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) + + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -218,9 +262,10 @@ typeOf = \case ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil + EAccum _ _ _ _ _ _ _ -> STNil EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t EPlus _ t _ _ -> fromSMTy t EOneHot _ t _ _ _ -> fromSMTy t @@ -246,12 +291,17 @@ extOf = \case ELCase x _ _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x + EMap x _ _ -> x EFold1Inner x _ _ _ _ -> x ESum1Inner x _ -> x EUnit x _ -> x EReplicate1Inner x _ _ -> x EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x + EZip x _ _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -261,8 +311,9 @@ extOf = \case ECustom x _ _ _ _ _ _ _ _ -> x ERecompute x _ -> x EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x EZero x _ _ -> x + EDeepZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x @@ -291,12 +342,17 @@ travExt f = \case ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e EUnit x e -> EUnit <$> f x <*> travExt f e EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -306,8 +362,9 @@ travExt f = \case ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 ERecompute x e -> ERecompute <$> f x <*> travExt f e EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 - EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s @@ -349,12 +406,17 @@ subst' f w = \case ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) + EZip x a b -> EZip x (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -364,8 +426,9 @@ subst' f w = \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s @@ -437,6 +500,16 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e + _ -> False + eTup :: SList (Ex env) list -> Ex env (Tup list) eTup = mkTup (ENil ext) (EPair ext) @@ -461,33 +534,26 @@ eidxEq (SS n) a b (eidxEq n (EFst ext (EVar ext ty (IS IZ))) (EFst ext (EVar ext ty IZ))) -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr _ t <- typeOf arr + , Dict <- styKnown t + = EMap ext f arr -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = - let STArr n t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr _ t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) + IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) + IS (IS i) -> EVar ext t (IS i)) + f) + (EZip ext arr1 arr2) ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = - let STArr _ t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 +ezip = EZip ext eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) @@ -503,11 +569,141 @@ eshapeEmpty (SS n) e = (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) -ezeroD2 :: STy t -> Ex env (D2 t) -ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) +eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) +eshapeConst ShNil = ENil ext +eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) + +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi -- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil -- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea -- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) -- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body + +-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +ecase e a b + | STEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ECase ext e a b + +elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext + +splitSparsePair + :: -- given a sparsity + STy (TPair a b) -> Sparse (TPair a b) t' + -> (forall a' b'. + -- I give you back two sparsities for a and b + Sparse a a' -> Sparse b b' + -- furthermore, I tell you that either your t' is already this (a', b') pair... + -> Either + (t' :~: TPair a' b') + -- or I tell you how to construct a' and b' from t', given an actual t' + (forall r' env. + Idx env t' + -> (forall env'. + (forall c. Ex env' c -> Ex env c) + -> Ex env' a' -> Ex env' b' -> r') + -> r') + -> r) + -> r +splitSparsePair _ SpAbsent k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = + k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = + let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in + k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> + k2 (elet $ + emaybe (EVar ext (STMaybe (applySparse s t)) i) + (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) + (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) + (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = + splitSparsePair t (SpSparse s) $ \s1 s2 eres -> + k s1 s2 $ Right $ \i k2 -> + case eres of + Left refl -> case refl of {} + Right f -> + f IZ $ \wrap e1 e2 -> + k2 (\body -> + elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) + (ENothing ext (applySparse s t)) + (evar IZ)) $ + wrap body) + e1 e2 diff --git a/src/AST/Accum.hs b/src/CHAD/AST/Accum.hs index 03369c8..ea74a95 100644 --- a/src/AST/Accum.hs +++ b/src/CHAD/AST/Accum.hs @@ -1,15 +1,14 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module AST.Accum where +module CHAD.AST.Accum where -import AST.Types -import CHAD.Types -import Data +import CHAD.AST.Types +import CHAD.Data data AcPrj @@ -35,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) - AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) - AcIdx (APLeft p) (TLEither a b) = AcIdx p a - AcIdx (APRight p) (TLEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, shapes info), recursive info) +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t @@ -75,19 +92,23 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. diff --git a/src/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs index 745a93b..c1a1e77 100644 --- a/src/AST/Bindings.hs +++ b/src/CHAD/AST/Bindings.hs @@ -13,12 +13,12 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module AST.Bindings where +module CHAD.AST.Bindings where -import AST -import AST.Env -import Data -import Lemmas +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data +import CHAD.Lemmas -- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. @@ -28,6 +28,10 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + mapBindings :: (forall env' t'. f env' t' -> g env' t') -> Bindings f env binds -> Bindings g env binds mapBindings _ BTop = BTop @@ -42,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) = let (b', w') = weakenBindings wf w b in (BPush b' (t, wf w' x), WCopy w') +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' weakenOver SNil w = w weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) @@ -69,7 +78,7 @@ collectBindings = \env -> fst . go env WId where go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) go _ _ SETop = (BTop, WId) - go (ty `SCons` env) w (SEYes sub) = + go (ty `SCons` env) w (SEYesR sub) = let (bs, w') = go env (WPop w) sub in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs new file mode 100644 index 0000000..46173d2 --- /dev/null +++ b/src/CHAD/AST/Count.hs @@ -0,0 +1,927 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +module CHAD.AST.Count where + +import Data.Functor.Product +import Data.Some +import Data.Type.Equality +import GHC.Generics (Generic, Generically(..)) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data + + +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where + -- | One of the two branches is taken + (<||>) :: a -> a -> a + -- | This code is executed many times + scaleMany :: a -> a + + +data Count = Zero | One | Many + deriving (Show, Eq, Ord) + +instance Semigroup Count where + Zero <> n = n + n <> Zero = n + _ <> _ = Many +instance Monoid Count where + mempty = Zero +instance Occurrence Count where + (<||>) = max + scaleMany Zero = Zero + scaleMany _ = Many + +data Occ = Occ { _occLexical :: Count + , _occRuntime :: Count } + deriving (Eq, Generic) + deriving (Semigroup, Monoid) via Generically Occ + +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + +instance Occurrence Occ where + Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) + scaleMany (Occ l c) = Occ l (scaleMany c) + + +data Substruc t t' where + -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below + SsFull :: Substruc t t + SsNone :: Substruc t TNil + SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') + SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') + SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') + SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') + SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements + SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) + where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' +pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) + where SsArr' = SsArr +{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} + +instance Semigroup (Some (Substruc t)) where + Some SsFull <> _ = Some SsFull + _ <> Some SsFull = Some SsFull + Some SsNone <> s = s + s <> Some SsNone = s + Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) + Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) + Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) + Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) + Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) + Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where + mempty = Some SsNone + +instance TestEquality (Substruc t) where + testEquality SsFull s = isFull s + testEquality s SsFull = sym <$> isFull s + testEquality SsNone SsNone = Just Refl + testEquality SsNone _ = Nothing + testEquality _ SsNone = Nothing + testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) + | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +-- simplifySubstruc' s1 $ \s1' f1 -> +-- simplifySubstruc' s2 $ \s2' f2 -> +-- case (s1', s2') of +-- (SsFull, SsFull) -> +-- k SsFull (case (f1, f2) of +-- (ExMapId, ExMapId) -> ExMapId +-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where + MapEmpty '[] = '[] + MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where + OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! + OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where + Some OccEnd <> e = e + e <> Some OccEnd = e + Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where + mempty = Some OccEnd + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where + Some OccEnd <||> e = e + e <||> Some OccEnd = e + Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + + scaleMany (Some OccEnd) = Some OccEnd + scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s + | Some env' <- onehotOccEnv i v s + = Some (OccPush env' mempty SsNone) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) + +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone + +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd + +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i + +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) + | Some (Pair s' i') <- occEnvPrjS e i + = Some (Pair s' (IS i')) + +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of + _ | Just Refl <- testEquality topsbig topssmall -> ex + + (SsFull, SsFull) -> ex + (SsNone, SsNone) -> ex + (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" + (_, SsNone) -> + case typeOf ex of + STNil -> ex + _ -> use ex $ ENil ext + + (SsPair s1 s2, SsPair s1' s2') -> + eunPair ex $ \_ e1 e2 -> + EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) + (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex + (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex + + (SsEither s1 s2, SsEither s1' s2') + | STEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in ecase ex + (EInl ext (typeOf e2) e1) + (EInr ext (typeOf e1) e2) + (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex + (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex + + (SsLEither s1 s2, SsLEither s1' s2') + | STLEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in elcase ex + (ELNil ext (typeOf e1) (typeOf e2)) + (ELInl ext (typeOf e2) e1) + (ELInr ext (typeOf e1) e2) + (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex + (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex + + (SsMaybe s1, SsMaybe s1') + | STMaybe t1 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + in emaybe ex + (ENothing ext (typeOf e1)) + (EJust ext e1) + (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex + (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex + + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex + (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex + (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + + (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" + (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex + (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where + EMRest :: Bool -> EnvMask env + EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx ex + | Some env <- occCountAll ex + = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) + where + fullOccEnv :: SList f env -> OccEnv () env env + fullOccEnv SNil = OccEnd + fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse +-- * k: is passed the actual environment usage of this expression, including +-- occurrence counts. The callback reconstructs a new expression in an +-- updated "response" environment. The response must be at least as large as +-- the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t + -> (forall env'. OccEnv Occ env env' + -- response OccEnv must be at least as large as the OccEnv returned above + -> (forall env''. OccEnv () env env'' -> Ex env'' t') + -> r) + -> r +occCountX initialS topexpr k = case topexpr of + EVar _ t i -> + withSome (onehotOccEnv i (Occ One One) s) $ \env -> + k env $ \env' -> + withSome (occEnvPrjS env' i) $ \(Pair s' i') -> + projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') + ELet _ rhs body -> + occCountX s body $ \envB mkbody -> + occEnvPop' envB $ \envB' s1 -> + occCountX s1 rhs $ \envR mkrhs -> + withSome (Some envB' <> Some envR) $ \env -> + k env $ \env' -> + ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) + EPair _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsPair' s1 s2 -> + occCountX s1 a $ \env1 mka -> + occCountX s2 b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPair ext (mka env') (mkb env') + EFst _ e -> + occCountX (SsPair s SsNone) e $ \env1 mke -> + k env1 $ \env' -> + EFst ext (mke env') + ESnd _ e -> + occCountX (SsPair SsNone s) e $ \env1 mke -> + k env1 $ \env' -> + ESnd ext (mke env') + ENil _ -> + case s of + SsFull -> k OccEnd (\_ -> ENil ext) + SsNone -> k OccEnd (\_ -> ENil ext) + EInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + EInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + EInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + EInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + ECase _ e a b -> + occCountX s a $ \env1' mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env1' $ \env1 s1 -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) + ENothing _ t -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EJust _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsMaybe s' -> + occCountX s' e $ \env1 mke -> + k env1 $ \env' -> + EJust ext (mke env') + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EMaybe _ a b e -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsMaybe s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') + ELNil _ t1 t2 -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + ELInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + ELInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELCase _ e a b c -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occCountX s c $ \env3' mkc -> + occEnvPop' env2' $ \env2 s1 -> + occEnvPop' env3' $ \env3 s2 -> + occCountX (SsLEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> + k env $ \env' -> + ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + + EConstArr _ n t x -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) + + EBuild _ n a b -> + case s of + SsNone -> + occCountX SsFull a $ \env1 mka -> + occCountX SsNone b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + use (EBuild ext n (mka env') $ + use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ + ENil ext) $ + ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX s' b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + EBuild ext n (mka env') $ + elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + + EFold1Inner _ commut a b c -> + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of + SsNone -> Some SsNone + SsArr' s' -> Some s' in + withSome (s1 <> s0) $ \sElt -> + occCountX sElt b $ \env2 mkb -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush env' () (SsPair sElt sElt))) + (mkb env') (mkc env') + + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e + + EUnit _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' s' -> + occCountX s' e $ \env mke -> + k env $ \env' -> + EUnit ext (mke env') + + EReplicate1Inner _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX (SsArr s') b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReplicate1Inner ext (mka env') (mkb env') + + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + + EFold1InnerD1 _ cm e1 e2 e3 -> + case s of + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm ef ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkebog env') (mked env') + + EConst _ t x -> + k OccEnd $ \_ -> + case s of + SsNone -> ENil ext + SsFull -> EConst ext t x + + EIdx0 _ e -> + occCountX (SsArr s) e $ \env1 mke -> + k env1 $ \env' -> + EIdx0 ext (mke env') + + EIdx1 _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX (SsArr s') a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx1 ext (mka env') (mkb env') + + EIdx _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + _ -> + occCountX (SsArr s) a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx ext (mka env') (mkb env') + + EShape _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX (SsArr SsNone) e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EShape ext (mke env') + + EOp _ op e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EOp ext op (mke env') + + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') + + ERecompute _ e -> + occCountX s e $ \env1 mke -> + k env1 $ \env' -> + ERecompute ext (mke env') + + EWith _ t a b -> + case s of + SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with + occCountX SsNone b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + withSome (case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone) $ \s1' -> + occCountX s1' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ + ENil ext + SsPair sB sA -> + occCountX sB b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + let s1' = case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone in + withSome (Some sA <> s1') $ \sA' -> + occCountX sA' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ + EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) + SsFull -> occCountX (SsPair SsFull SsFull) topexpr k + + EAccum _ t p a sp b e -> + -- TODO: do better! + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull e $ \env3 mke -> + withSome (Some env1 <> Some env2) $ \env12 -> + withSome (Some env12 <> Some env3) $ \env -> + k env $ \env' -> + case s of {SsFull -> id; SsNone -> id} $ + EAccum ext t p (mka env') sp (mkb env') (mke env') + + EZero _ t e -> + occCountX (subZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EZero ext (applySubstrucM s t) (mke env') + where + subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) + subZeroInfo SsFull = SsFull + subZeroInfo SsNone = SsNone + subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) + subZeroInfo SsEither{} = error "Either is not a monoid" + subZeroInfo SsLEither{} = SsNone + subZeroInfo SsMaybe{} = SsNone + subZeroInfo (SsArr s') = SsArr (subZeroInfo s') + subZeroInfo SsAccum{} = error "Accum is not a monoid" + + EDeepZero _ t e -> + occCountX (subDeepZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EDeepZero ext (applySubstrucM s t) (mke env') + where + subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) + subDeepZeroInfo SsFull = SsFull + subDeepZeroInfo SsNone = SsNone + subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo SsEither{} = error "Either is not a monoid" + subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') + subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') + subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" + + EPlus _ t a b -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPlus ext (applySubstrucM s t) (mka env') (mkb env') + + EOneHot _ t p a b -> + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> -- TODO: do better + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') + + EError _ t msg -> + k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + where + s = simplifySubstruc (typeOf topexpr) initialS + + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr' SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') + + +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = + deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = + deleteUnused env (Some occenv) $ \sub -> + case count of Zero -> k (SENo sub) + _ -> k (SEYesR sub) + +unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv = \sub -> + subst (\x t i -> case sinkViaSubenv i sub of + Just i' -> EVar x t i' + Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") + where + sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) + sinkViaSubenv IZ (SEYesR _) = Just IZ + sinkViaSubenv IZ (SENo _) = Nothing + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs new file mode 100644 index 0000000..8e6b745 --- /dev/null +++ b/src/CHAD/AST/Env.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Env where + +import Data.Type.Equality + +import CHAD.AST.Sparse +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +-- | @env'@ is a subset of @env@: each element of @env@ is either included in +-- @env'@ ('SEYes') or not included in @env'@ ('SENo'). +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} + +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' +subList SNil SETop = SNil +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) +subList (SCons _ xs) (SENo sub) = subList xs sub + +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env +subenvAll SNil = SETop +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) + +subenvNone :: SList f env -> Subenv' s env '[] +subenvNone SNil = SETop +subenvNone (SCons _ env) = SENo (subenvNone env) + +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} + +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 +subenvCompose SETop SETop = SETop +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) + +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') +subenvConcat sub1 SETop = sub1 +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) +subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) + +-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 +-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r +-- subenvSplit SNil sub k = k SETop sub +-- subenvSplit (SCons _ list) (SENo sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SENo sub1) sub2 +-- subenvSplit (SCons _ list) (SEYes s sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SEYes s sub1) sub2 + +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 +sinkWithSubenv SETop = WId +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SENo sub) = sinkWithSubenv sub + +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env +wUndoSubenv SETop = WId +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index 41da656..9ddcb35 100644 --- a/src/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -1,32 +1,31 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where +module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) import Data.Functor.Const -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.String (fromString) import Prettyprinter import Prettyprinter.Render.String -import qualified Data.Text.Lazy as TL -import qualified Prettyprinter.Render.Terminal as PT +import Data.Text.Lazy qualified as TL +import Prettyprinter.Render.Terminal qualified as PT import System.Console.ANSI (hSupportsANSI) import System.IO (stdout) import System.IO.Unsafe (unsafePerformIO) -import AST -import AST.Count -import CHAD.Types -import Data +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types class PrettyX x where @@ -70,6 +69,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -201,16 +201,22 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -233,6 +239,38 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] + + EFold1InnerD1 _ cm a b c -> do + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -304,18 +342,24 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] EZero _ t e1 -> do e1' <- ppExpr' 11 val e1 return $ ppParen (d > 0) $ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b @@ -368,6 +412,20 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -461,4 +519,5 @@ render = else renderString) . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } where + {-# NOINLINE stdoutTTY #-} stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs new file mode 100644 index 0000000..85f2882 --- /dev/null +++ b/src/CHAD/AST/Sparse.hs @@ -0,0 +1,296 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where + +import Data.Type.Equality + +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data (SBool(..)) + + +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + + +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This simplifies the sparsePlusS code. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. TODO can this be improved? +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k sp3 + (Inj $ \a -> emaybe a (inj2 zero2) (inj1 (evar IZ))) + (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) + | otherwise = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs new file mode 100644 index 0000000..8f41ba4 --- /dev/null +++ b/src/CHAD/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Sparse.Types where + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + +import CHAD.AST.Types + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs index 3c353d4..34267e4 100644 --- a/src/AST/SplitLets.hs +++ b/src/CHAD/AST/SplitLets.hs @@ -7,13 +7,13 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where +module CHAD.AST.SplitLets (splitLets) where import Data.Type.Equality -import AST -import AST.Bindings -import Lemmas +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Lemmas splitLets :: Ex env t -> Ex env t @@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) @@ -34,7 +34,14 @@ splitLets' = \sub -> \case in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -49,11 +56,14 @@ splitLets' = \sub -> \case ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) @@ -63,8 +73,9 @@ splitLets' = \sub -> \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s @@ -88,15 +99,42 @@ splitLets' = \sub -> \case -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') + bs1 = fst (weakenBindingsE WSink bs1') (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] diff --git a/src/AST/Types.hs b/src/CHAD/AST/Types.hs index a3b7302..f0feb55 100644 --- a/src/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -5,10 +5,10 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-} -module AST.Types where +module CHAD.AST.Types where import Data.Int (Int32, Int64) import Data.GADT.Compare @@ -16,7 +16,7 @@ import Data.GADT.Show import Data.Kind (Type) import Data.Type.Equality -import Data +import CHAD.Data type data Ty @@ -31,6 +31,8 @@ type data Ty type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool +-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes +-- convenient, but such scalar types are not special in any way. type STy :: Ty -> Type data STy t where STNil :: STy TNil @@ -171,15 +173,25 @@ type family ScalIsIntegral t where ScalIsIntegral TBool = False -- | Returns true for arrays /and/ accumulators. -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STLEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = True +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True type family Tup env where Tup '[] = TNil diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs new file mode 100644 index 0000000..d3cad25 --- /dev/null +++ b/src/CHAD/AST/UnMonoid.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where + +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data + + +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. +unMonoid :: Ex env t -> Ex env t +unMonoid = \case + EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e + EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) + EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) + + EVar _ t i -> EVar ext t i + ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) + EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) + EFst _ e -> EFst ext (unMonoid e) + ESnd _ e -> ESnd ext (unMonoid e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext t (unMonoid e) + EInr _ t e -> EInr ext t (unMonoid e) + ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unMonoid e) + EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unMonoid e) + ELInr _ t e -> ELInr ext t (unMonoid e) + ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) + ESum1Inner _ e -> ESum1Inner ext (unMonoid e) + EUnit _ e -> EUnit ext (unMonoid e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unMonoid e) + EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) + EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) + EShape _ e -> EShape ext (unMonoid e) + EOp _ op e -> EOp ext op (unMonoid e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ERecompute _ e -> ERecompute ext (unMonoid e) + EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) + EError _ t s -> EError ext t s + +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +-- don't destroy the effects! +zero SMTNil e = ELet ext e $ ENil ext +zero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) +zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e +zero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t +-- don't destroy the effects! +plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext +plus (SMTPair t1 t2) a b = + eunPair a $ \w1 a1 a2 -> + eunPair (weakenExpr w1 b) $ \w2 b1 b2 -> + EPair ext (plus t1 (weakenExpr w2 a1) b1) + (plus t2 (weakenExpr w2 a2) b2) +plus (SMTLEither t1 t2) a b = + let t = STLEither (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + ELCase ext (EVar ext t (IS IZ)) + (EVar ext t IZ) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) + (EError ext t "plus l+r")) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (EError ext t "plus r+l") + (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) +plus (SMTMaybe t) a b = + ELet ext b $ + EMaybe ext + (EVar ext (STMaybe (fromSMTy t)) IZ) + (EJust ext + (EMaybe ext + (EVar ext (fromSMTy t) IZ) + (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) + (weakenExpr WSink a) +plus (SMTArr _ t) a b = + ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + a b +plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) + +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t +onehot typ topprj idx arg = case (typ, topprj) of + (_, SAPHere) -> + ELet ext arg $ + EVar ext (fromSMTy typ) IZ + + (SMTPair t1 t2, SAPFst prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t1 in + EPair ext (EVar ext toh IZ) + (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) + + (SMTPair t1 t2, SAPSnd prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t2 in + EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) + (EVar ext toh IZ) + + (SMTLEither t1 t2, SAPLeft prj) -> + ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + (SMTLEither t1 t2, SAPRight prj) -> + ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) + + (SMTMaybe t1, SAPJust prj) -> + EJust ext (onehot t1 prj idx arg) + + (SMTArr n t1, SAPArrIdx prj) -> + let tidx = tTup (sreplicate n tIx) + in ELet ext idx $ + EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ + zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken.hs b/src/CHAD/AST/Weaken.hs index d882e28..ac0d152 100644 --- a/src/AST/Weaken.hs +++ b/src/CHAD/AST/Weaken.hs @@ -15,14 +15,15 @@ -- The reason why this is a separate module with "little" in it: {-# LANGUAGE AllowAmbiguousTypes #-} -module AST.Weaken (module AST.Weaken, Append) where +module CHAD.AST.Weaken (module CHAD.AST.Weaken, Append) where import Data.Bifunctor (first) import Data.Functor.Const +import Data.GADT.Compare import Data.Kind (Type) -import Data -import Lemmas +import CHAD.Data +import CHAD.Lemmas type Idx :: [k] -> k -> Type @@ -31,6 +32,11 @@ data Idx env t where IS :: Idx env t -> Idx (a : env) t deriving instance Show (Idx env t) +instance GEq (Idx env) where + geq IZ IZ = Just Refl + geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl + geq _ _ = Nothing + splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) splitIdx SNil i = Right i splitIdx (SCons _ _) IZ = Left IZ @@ -123,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/CHAD/AST/Weaken/Auto.hs index 6752c24..229940b 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/CHAD/AST/Weaken/Auto.hs @@ -1,35 +1,34 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS_GHC -Wno-partial-type-signatures #-} -module AST.Weaken.Auto ( +module CHAD.AST.Weaken.Auto ( autoWeak, (&.), auto, auto1, Layout(..), ) where import Data.Functor.Const +import Data.Kind (Constraint) import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import AST.Weaken -import Data -import Lemmas +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Lemmas type family Lookup name list where @@ -39,18 +38,21 @@ type family Lookup name list where -- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) -- | Pre-weaken with a weakening LPreW :: forall name1 name2 segments. SegmentName name1 -> SegmentName name2 -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments (Lookup name1 segments) - (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where fromLabel = LSeg (symbolSing @name) newtype SegmentName name = SegmentName (SSymbol name) @@ -60,11 +62,23 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh fromLabel = SegmentName symbolSing +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil auto :: KnownListSpine list => SList (Const ()) list @@ -74,7 +88,7 @@ auto1 :: SList (Const ()) '[t] auto1 = Const () `SCons` SNil infixr &. -(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) (&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) @@ -118,12 +132,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ seg) +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinAppPreW name1 name2 w LinEnd @@ -151,8 +165,7 @@ pullDown segs name@SSymbol linlayout kNotFound k = k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) -sortLinLayouts :: forall segments env1 env2. - SSegments segments +sortLinLayouts :: SSegments segments -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) @@ -169,8 +182,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" -autoWeak :: forall segments env1 env2. - SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 autoWeak segs ly1 ly2 = preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs deleted file mode 100644 index d8a71b5..0000000 --- a/src/CHAD/Accum.hs +++ /dev/null @@ -1,27 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -module CHAD.Accum where - -import AST -import CHAD.Types -import Data - - - -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = - makeAccumulators envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - diff --git a/src/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs index 4501c32..212cc7d 100644 --- a/src/Analysis/Identity.hs +++ b/src/CHAD/Analysis/Identity.hs @@ -3,7 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -module Analysis.Identity ( +module CHAD.Analysis.Identity ( identityAnalysis, identityAnalysis', ValId(..), @@ -13,11 +13,11 @@ module Analysis.Identity ( import Data.Foldable (toList) import Data.List (intercalate) -import AST -import AST.Pretty (PrettyX(..)) -import CHAD.Types (d1, d2) -import Data -import Util.IdGen +import CHAD.AST +import CHAD.AST.Pretty (PrettyX(..)) +import CHAD.Data +import CHAD.Drev.Types (d1, d2) +import CHAD.Util.IdGen -- | Every array, scalar and accumulator has an ID. Trivial values such as @@ -202,11 +202,19 @@ idana env expr = case expr of res <- VIArr <$> genId <*> shidsToVec dim shids pure (res, EBuild res dim e1' e2') + EMap _ e1 e2 -> do + let STArr _ t = typeOf e2 + x1 <- genIds t + (_, e1') <- idana (x1 `SCons` env) e1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure sh + pure (res, EMap res e1' e2') + EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 @@ -244,6 +252,41 @@ idana env expr = case expr of res <- VIArr <$> genId <*> pure sh pure (res, EMinimum1Inner res e1') + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + + EZip _ e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + let VIArr _ sh = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EZip res e1' e2') + + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') + EConst _ t val -> do res <- VIScal <$> genId pure (res, EConst res t val) @@ -307,11 +350,11 @@ idana env expr = case expr of let res = VIPair v2 x2 pure (res, EWith res t e1' e2') - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' e2' e3') + pure (VINil, EAccum VINil t prj e1' sp e2' e3') EZero _ t e1 -> do -- Approximate the result of EZero to be independent from the zero info @@ -320,6 +363,13 @@ idana env expr = case expr of res <- genIds (fromSMTy t) pure (res, EZero res t e1') + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 diff --git a/src/Array.hs b/src/CHAD/Array.hs index 707dce2..caf63ef 100644 --- a/src/Array.hs +++ b/src/CHAD/Array.hs @@ -2,19 +2,20 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} -module Array where +module CHAD.Array where import Control.DeepSeq import Control.Monad.Trans.State.Strict import Data.Foldable (traverse_) import Data.Vector (Vector) -import qualified Data.Vector as V +import Data.Vector qualified as V import GHC.Generics (Generic) -import Data +import CHAD.Data data Shape n where @@ -91,6 +92,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) arrayToList :: Array n t -> [t] arrayToList (Array _ v) = V.toList v +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/src/Compile.hs b/src/CHAD/Compile.hs index 722b432..44a335c 100644 --- a/src/Compile.hs +++ b/src/CHAD/Compile.hs @@ -2,13 +2,14 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module CHAD.Compile (compile, compileStderr) where import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) @@ -20,35 +21,37 @@ import Data.Bifunctor (first) import Data.Char (ord) import Data.Foldable (toList) import Data.Functor.Const -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.Functor.Product (Product) import Data.IORef import Data.List (foldl1', intersperse, intercalate) -import qualified Data.Map.Strict as Map +import Data.Map.Strict qualified as Map import Data.Maybe (fromMaybe) -import qualified Data.Set as Set +import Data.Set qualified as Set import Data.Set (Set) import Data.Some -import qualified Data.Vector as V +import Data.Vector qualified as V import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) import System.IO.Unsafe (unsafePerformIO) import Prelude hiding ((^)) -import qualified Prelude +import Prelude qualified -import Array -import AST -import AST.Pretty (ppSTy, ppExpr) -import Compile.Exec -import Data -import Interpreter.Rep -import qualified Util.IdGen as IdGen +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty (ppSTy, ppExpr) +import CHAD.AST.Sparse.Types (isDense) +import CHAD.Compile.Exec +import CHAD.Data +import CHAD.Interpreter.Rep +import CHAD.Util.IdGen qualified as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -69,28 +72,30 @@ debugAllocs :: Bool; debugAllocs = toEnum 0 -- | Emit extra C code that checks stuff emitChecks :: Bool; emitChecks = toEnum 0 +-- | Returns compiled function plus compilation output (warnings) compile :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t)) + -> IO (SList Value env -> IO (Rep t), String) compile = \env expr -> do codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source "kernel" + (lib, compileOutput) <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type - return $ \val -> do - allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) - serialiseArguments args ptr $ do - callKernelFun lib ptr - ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) - when (ok /= 1) $ - ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) - deserialise result_type ptr (koResultOffset offsets) + let function val = do + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) + serialiseArguments args ptr $ do + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) + return (function, compileOutput) where serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -98,6 +103,15 @@ compile = \env expr -> do serialiseArguments args ptr k serialiseArguments _ _ k = k +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do + (fun, output) <- compile env expr + when (not (null output)) $ + hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + data StructDecl = StructDecl String -- ^ name @@ -125,7 +139,7 @@ data CExpr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ (<type)<expr> + | CECast String CExpr -- ^ (<type>)<expr> deriving (Show) printStructDecl :: StructDecl -> ShowS @@ -214,23 +228,31 @@ repSTy (STScal st) = case st of STBool -> "uint8_t" repSTy t = genStructName t -genStructName :: STy t -> String -genStructName = \t -> "ty_" ++ gen t where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STLEither a b) = 'L' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen (fromSMTy t) +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the @@ -238,60 +260,56 @@ genStructName = \t -> "ty_" ++ gen t where -- -- For accumulation it is important that for struct representations of monoid -- types, the all-zero-bytes value corresponds to the zero value of that type. -genStruct :: String -> STy t -> [StructDecl] -genStruct name topty = case topty of +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of STNil -> - [StructDecl name "" com] + TreeNode [StructDecl name "" com] [] STPair a b -> - [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] STEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - -- TODO: put shape in the main struct, not the buffer; it's constant, after all -- TODO: no buffer if n = 0 - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] STScal _ -> - [] + TreeNode [] [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] where + name = genStructName topty com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () -genStructs ty = do - let name = genStructName ty - seen <- lift $ gets (name `Set.member`) - - if seen - then pure () - else do - -- already mark this struct as generated now, so we don't generate it - -- twice (unnecessary because no recursive types, but y'know) - lift $ modify (Set.insert name) - - () <- case ty of - STNil -> pure () - STPair a b -> genStructs a >> genStructs b - STEither a b -> genStructs a >> genStructs b - STLEither a b -> genStructs a >> genStructs b - STMaybe t -> genStructs t - STArr _ t -> genStructs t - STScal _ -> pure () - STAccum t -> genStructs (fromSMTy t) - - tell (BList (genStruct name ty)) +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) data CompState = CompState { csStructs :: Set (Some STy) @@ -340,6 +358,12 @@ emitStruct ty = CompM $ do modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } return (genStructName ty) +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + emitTLD :: String -> CompM () emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } @@ -427,10 +451,10 @@ compileToString codeID env expr = else id ,showString $ " const bool success = typed_kernel(" ++ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ - concat (map (\((arg, typ), off) -> - ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ " /* " ++ arg ++ " */") - (zip arg_pairs arg_offsets)) ++ + concat (zipWith (\(arg, typ) off -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + arg_pairs arg_offsets) ++ "\n );\n" ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" @@ -480,19 +504,18 @@ serialise topty topval ptr off k = serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t - allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do when debugRefc $ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh - pokeShape bufptr 0 n sh - pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63) + pokeByteOff @Word64 bufptr 0 (2 ^ 63) - let off1 = fromSNat n * 8 + 8 - loop i + let loop i | i == shapeSize sh = k | otherwise = - serialise t (vec V.! i) bufptr (off1 + i * eltsz) $ + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ loop (i+1) loop 0 (STScal sty, x) -> case sty of @@ -532,13 +555,12 @@ deserialise topty ptr off = else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape bufptr 0 n - refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 when debugRefc $ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let off1 = 8 * fromSNat n + 8 - eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) when (refc < 2 ^ 62) $ free bufptr return arr STScal sty -> case sty of @@ -576,7 +598,7 @@ metricsSTy (STLEither a b) = metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr _ _) = (8, 8) +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) metricsSTy (STScal sty) = case sty of STI32 -> (4, 4) STI64 -> (8, 8) @@ -599,7 +621,7 @@ peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) peekShape ptr off = \case SZ -> return ShNil SS n -> ShCons <$> peekShape ptr off n - <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) + <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + fromSNat n * 8)) compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case @@ -747,15 +769,17 @@ compile' env = \case return (CELit retvar) EConstArr _ n t (Array sh vec) -> do - strname <- emitStruct (STArr n (STScal t)) + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" -- Give it a refcount of _half_ the size_t max, so that it can be -- incremented and decremented at will and will "never" reach anything -- where something happens - emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++ - "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ - ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) EBuild _ n esh efun -> do shname <- compileAssign "sh" env esh @@ -770,7 +794,7 @@ compile' env = \case emit $ SBlock $ pure (SVarDecl False "size_t" linivar (CELit "0")) <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx)))) + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) @@ -779,6 +803,15 @@ compile' env = \case return (CELit arrname) + -- TODO: actually generate decent code here + EMap _ e1 e2 -> do + let STArr n _ = typeOf e2 + compile' env $ + elet e2 $ + EBuild ext n (EShape ext (evar IZ)) $ + elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e1 + EFold1Inner _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr @@ -799,7 +832,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name @@ -808,22 +841,26 @@ compile' env = \case -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> funStmts - <> pure (SAsg accvar funres)) + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) incrementVarAlways "foldx0" Decrement t x0name @@ -845,7 +882,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) let vecwid = 8 :: Int ivar <- genName' "i" @@ -909,6 +946,149 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + -- TODO: actually generate decent code here + EZip _ e1 e2 -> do + let STArr n _ = typeOf e1 + compile' env $ + elet e1 $ + elet (weakenExpr WSink e2) $ + EBuild ext n (EShape ext (evar (IS IZ))) $ + EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -934,7 +1114,7 @@ compile' env = \case when emitChecks $ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ arrname ++ ".buf); return false;") @@ -990,7 +1170,7 @@ compile' env = \case accname <- genName' "accum" emit $ SVarDecl False actyname accname (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) - emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + emit $ SAsg (accname++".buf->ac") (fromMaybe (CELit name1) mcopy) emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." e2' <- compile' (Const accname `SCons` env) e2 @@ -1002,95 +1182,7 @@ compile' env = \case rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - EAccum _ t prj eidx eval eacc -> do - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a - -- full zero array with the given zero info (for the type SMTArr n t1). - initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () - initZeroArray n t1 v vzi = do - shszname <- genName' "inacshsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) - newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) - emit $ SAsg v (CELit newarrName) - forM_ (initZeroFromMemset t1) $ \f1 -> do - ivar <- genName' "i" - ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts - - -- If something needs to be done to properly initialise this type to - -- zero after memory has already been initialised to all-zero bytes, - -- returns an action that does so. - -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) - initZeroFromMemset SMTNil = Nothing - initZeroFromMemset (SMTPair t1 t2) = - case (initZeroFromMemset t1, initZeroFromMemset t2) of - (Nothing, Nothing) -> Nothing - (mf1, mf2) -> Just $ \v vzi -> do - forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") - forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") - initZeroFromMemset SMTLEither{} = Nothing - initZeroFromMemset SMTMaybe{} = Nothing - initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi - initZeroFromMemset SMTScal{} = Nothing - - let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroZI :: SMTy a -> String -> String -> CompM () - initZeroZI SMTNil _ _ = return () - initZeroZI (SMTPair t1 t2) v vzi = do - initZeroZI t1 (v++".a") (vzi++".a") - initZeroZI t2 (v++".b") (vzi++".b") - initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZeroZI (SMTScal sty) v _ = case sty of - STI32 -> emit $ SAsg v (CELit "0") - STI64 -> emit $ SAsg v (CELit "0l") - STF32 -> emit $ SAsg v (CELit "0.0f") - STF64 -> emit $ SAsg v (CELit "0.0") - - let -- Initialise an uninitialised accumulation value, potentially already - -- with the addend, potentially to zero depending on the nature of the - -- projection. - -- 1. If the projection indexes only through dense monoids before - -- reaching SAPHere, the thing cannot be initialised to zero with - -- only an AcIdx; it would need to model a zero after the addend, - -- which is stupid and redundant. In this case, we return Left: - -- (accumulation value) (AcIdx value) (addend value). - -- The addend is copied, not consumed. (We can't reliably _always_ - -- consume it, so it's not worth trying to do it sometimes.) - -- 2. Otherwise, a sparse monoid is found along the way, and we can - -- initalise the dense prefix of the path to zero by setting the - -- indexed-through sparse value to a sparse zero. Afterwards, the - -- main recursion can proceed further. In this case, we return - -- Right: (accumulation value) (AcIdx value) - -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) - initZeroChunk :: SMTy a -> SAcPrj p a b - -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend - (String -> String -> CompM ()) -- zero initialisation of sparse chunk - initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of - -- reached target before the first sparse constructor - (t1 , SAPHere ) -> Left $ \v _ addend -> do - incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend - emit $ SAsg v (CELit addend) - -- sparse types - (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - -- dense types - (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - f (v++".a") (i++".a") - initZeroZI t2 (v++".b") (i++".b") - (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do - initZeroZI t1 (v++".a") (i++".a") - f (v++".b") (i++".b") - (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - initZeroArray n t1 v (i++".a.b") - linidxvar <- genName' "li" - emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) - f (v++".buf->xs["++linidxvar++"]") (i++".b") - where - applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i - applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i - + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do let -- Add a value (s) into an existing accumulation value (d). If a sparse -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () @@ -1131,16 +1223,16 @@ compile' env = \case when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ [0 .. fromSNat n - 1] $ \j -> do - emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) "!=" - (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (CELit (d ++ ".sh[" ++ show j ++ "]"))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ d ++ ".buf" ++ - concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ ", " ++ s ++ ".buf" ++ - concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ "); " ++ "return false;") mempty @@ -1160,67 +1252,55 @@ compile' env = \case accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend - accumRef (SMTLEither ta tb) prj0 v i addend = do - let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' - SAPRight prj' -> initZeroChunk tb prj' - subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") - tagval = case prj0 of SAPLeft{} -> "1" - SAPRight{} -> "2" - ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend - SAPRight prj' -> accumRef tb prj' subv i addend - case chunkres of - Left densef -> do - ((), stmtsSet) <- scope $ densef subv i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) - stmtsAdd -- TODO: emit check for consistency of tags? - Right sparsef -> do - ((), stmtsInit) <- scope $ sparsef subv i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty - forM_ stmtsAdd emit + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - case initZeroChunk tj prj' of - Left densef -> do - ((), stmtsSet1) <- scope $ densef (v++".j") i addend - ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) - stmtsAdd1 - Right sparsef -> do - ((), stmtsInit1) <- scope $ sparsef (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i addend + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend nameidx <- compileAssign "acidx" env eidx nameval <- compileAssign "acval" env eval @@ -1234,6 +1314,9 @@ compile' env = \case return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] @@ -1247,6 +1330,7 @@ compile' env = \case return $ CEStruct name [] EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" @@ -1363,21 +1447,21 @@ toLinearIdx SZ _ _ = CELit "0" toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") toLinearIdx (SS n) arrvar idxvar = CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n))))) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) "+" (CELit (idxvar ++ ".b")) -- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr -- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] -- fromLinearIdx (SS n) arrvar idxvar = do -- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) -- _ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" @@ -1392,9 +1476,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr])] - forM_ (zip shape [0::Int ..]) $ \(dim, i) -> - emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") when debugRefc $ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" @@ -1405,16 +1488,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] compileShapeQuery (SS n) var = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) -- | Takes a variable name for the array, not the buffer. compileArrShapeComponents :: SNat n -> String -> [CExpr] compileArrShapeComponents n var = - [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1433,6 +1516,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] @@ -1505,7 +1591,7 @@ compileExtremum nameBase opName operator env e = do lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" @@ -1576,7 +1662,7 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (hasArrays (fromSMTy t)) -> do + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" emit $ SVarDeclUninit toptyname name @@ -1585,7 +1671,7 @@ copyForWriting topty var = case topty of let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" emit $ SVerbatim $ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ ");" emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) @@ -1596,8 +1682,7 @@ copyForWriting topty var = case topty of in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ printCExpr 0 databytes ");"]) @@ -1614,8 +1699,7 @@ copyForWriting topty var = case topty of name <- genName emit $ SVarDecl False toptyname name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain diff --git a/src/Compile/Exec.hs b/src/CHAD/Compile/Exec.hs index 9b9fb15..ffe5661 100644 --- a/src/Compile/Exec.hs +++ b/src/CHAD/Compile/Exec.hs @@ -1,6 +1,5 @@ {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} -module Compile.Exec ( +module CHAD.Compile.Exec ( KernelLib, buildKernel, callKernelFun, @@ -30,7 +29,7 @@ debug = False -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) -buildKernel :: String -> String -> IO KernelLib +buildKernel :: String -> String -> IO (KernelLib, String) buildKernel csource funname = do template <- (++ "/tmp.chad.") <$> getTempDir path <- mkdtemp template @@ -42,7 +41,9 @@ buildKernel csource funname = do ,"-o", outso, "-" ,"-Wall", "-Wextra" ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" - ,"-Wno-unused-parameter", "-Wno-unused-function"] + ,"-Wno-unused-parameter", "-Wno-unused-function" + ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives + ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource -- Print the source before the GCC output. @@ -50,11 +51,6 @@ buildKernel csource funname = do ExitSuccess -> return () ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" - when (not (null gccStdout)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stdout: <<<\n" ++ gccStdout ++ ">>>" - when (not (null gccStderr)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stderr: <<<\n" ++ gccStderr ++ ">>>" - case ec of ExitSuccess -> return () ExitFailure{} -> do @@ -71,7 +67,7 @@ buildKernel csource funname = do _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" dlclose dl) - return (KernelLib ref) + return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr) foreign import ccall "dynamic" wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () diff --git a/src/Data.hs b/src/CHAD/Data.hs index e86aaa6..8c7605c 100644 --- a/src/Data.hs +++ b/src/CHAD/Data.hs @@ -8,16 +8,17 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl)) where +module CHAD.Data (module CHAD.Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare import Data.GADT.Show import Data.Some +import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) -import Lemmas (Append) +import CHAD.Lemmas (Append) data Dict c where @@ -184,3 +185,8 @@ instance Applicative Bag where instance Semigroup (Bag t) where (<>) = BTwo instance Monoid (Bag t) where mempty = BNone + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) diff --git a/src/Data/VarMap.hs b/src/CHAD/Data/VarMap.hs index 9c10421..a0d7617 100644 --- a/src/Data/VarMap.hs +++ b/src/CHAD/Data/VarMap.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -module Data.VarMap ( +module CHAD.Data.VarMap ( VarMap, empty, insert, @@ -20,16 +21,16 @@ module Data.VarMap ( import Prelude hiding (lookup) -import qualified Data.Map.Strict as Map +import Data.Map.Strict qualified as Map import Data.Map.Strict (Map) import Data.Maybe (mapMaybe) import Data.Some -import qualified Data.Vector.Storable as VS +import Data.Vector.Storable qualified as VS import Unsafe.Coerce -import AST.Env -import AST.Types -import AST.Weaken +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken type role VarMap _ nominal -- ensure that 'env' is not phantom @@ -74,7 +75,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' subMap subenv = let bools = let loop :: Subenv env env' -> [Bool] loop SETop = [] - loop (SEYes sub) = True : loop sub + loop (SEYesR sub) = True : loop sub loop (SENo sub) = False : loop sub in VS.fromList $ loop subenv newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools @@ -89,7 +90,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env superMap subenv = let loop :: Subenv env env' -> Int -> [Int] loop SETop _ = [] - loop (SEYes sub) i = i : loop sub (i+1) + loop (SEYesR sub) i = i : loop sub (i+1) loop (SENo sub) i = loop sub (i+1) newIndices = VS.fromList $ loop subenv 0 diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs new file mode 100644 index 0000000..bfa964b --- /dev/null +++ b/src/CHAD/Drev.hs @@ -0,0 +1,1581 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.Drev ( + drev, + freezeRet, + CHADConfig(..), + defaultConfig, + Storage(..), + Descr(..), + Select, +) where + +import Data.Functor.Const +import Data.Some +import Data.Type.Equality (type (==), testEquality) + +import CHAD.Analysis.Identity (ValId(..), validSplitEither) +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.Weaken.Auto +import CHAD.Data +import CHAD.Data.VarMap qualified as VarMap +import CHAD.Data.VarMap (VarMap) +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types +import CHAD.Lemmas + + +------------------------------ TAPES AND BINDINGS ------------------------------ + +type family Tape binds where + Tape '[] = TNil + Tape (t : ts) = TPair t (Tape ts) + +tapeTy :: SList STy binds -> STy (Tape binds) +tapeTy SNil = STNil +tapeTy (SCons t ts) = STPair t (tapeTy ts) + +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = + EPair ext (EVar ext t (w @> IZ)) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) + +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) + +-- In order from large to small: i.e. in reverse order from what we want, +-- because in a Bindings, the head of the list is the bottom-most entry. +type family TapeUnfoldings binds where + TapeUnfoldings '[] = '[] + TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts + +type family Reverse l where + Reverse '[] = '[] + Reverse (t : ts) = Append (Reverse ts) '[t] + +-- An expression that is always 'snd' +data UnfExpr env t where + UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t + +fromUnfExpr :: UnfExpr env t -> Ex env t +fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) + +-- - A bunch of 'snd' expressions taking us from knowing that there's a +-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix +-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in +-- the environment. +-- - In the extended environment, another bunch of let bindings (these are +-- 'fst' expressions, but no need to know that statically) that project the +-- fsts out of what we introduced above, one for each type in 'ts'. +data Reconstructor env ts = + Reconstructor + (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) + (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) + +ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) +ssnoc SNil a = SCons a SNil +ssnoc (SCons t ts) a = SCons t (ssnoc ts a) + +sreverse :: SList f ts -> SList f (Reverse ts) +sreverse SNil = SNil +sreverse (SCons t ts) = ssnoc (sreverse ts) t + +stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) +stapeUnfoldings SNil = SNil +stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) + +-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. +shiftUnfolder + :: STy t + -> SList STy ts + -> Bindings UnfExpr (Tape ts : env) list + -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) +shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) +shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = + -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order + -- to expand an 'Append' in the types so that things simplify just enough. + -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list + -- of bindings produced by 'b'. We want to conclude from this that + -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know + -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after + -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. + BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t + BPush{} -> UnfExSnd itemTy t) + +growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) +growRecon t ts (Reconstructor unfbs bs) + | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) + , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) + , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env + = Reconstructor + (shiftUnfolder t ts unfbs) + -- Add a 'fst' at the bottom of the builder stack. + -- First we have to weaken most of 'bs' to skip one more binding in the + -- unfolder stack above it. + (BPush (fst (weakenBindingsE + (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) + (WSink :: env :> (Tape (t : ts) : env))) bs)) + (t + ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ + wSinks @(Tape (t : ts) : env) + (sappend ts + (sappend (sappend (sreverse (stapeUnfoldings ts)) + (SCons (tapeTy ts) SNil)) + SNil)) + @> IZ)) + +buildReconstructor :: SList STy ts -> Reconstructor env ts +buildReconstructor SNil = Reconstructor BTop BTop +buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) + +-- STRATEGY FOR reconstructBindings +-- +-- binds = [] +-- e : () +-- +-- binds = [c] +-- e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst e : c +-- +-- binds = [b, c] +-- e : (b, (c, ())) +-- x1 = snd e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- +-- binds = [a, b, c] +-- e : (a, (b, (c, ()))) +-- x2 = snd e : (b, (c, ())) +-- x1 = snd x2 : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- y3 = fst x3 : a + +-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all +-- the things in the list 'binds', we want to create a let stack that extracts +-- all values from that tuple and in effect "restores" the environment +-- described by 'binds'. The idea is that elsewhere, we took a slice of the +-- environment and saved it all in a tuple to be restored later. We +-- incidentally also add a bunch of additional bindings, namely 'Reverse +-- (TapeUnfoldings binds)', so the calling code just has to skip those in +-- whatever it wants to do. +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) + ,SList STy (Reverse (TapeUnfoldings binds))) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) + + +---------------------------------- DERIVATIVES --------------------------------- + +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e +d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e +d1op (ORecip t) e = EOp ext (ORecip t) e +d1op (OExp t) e = EOp ext (OExp t) e +d1op (OLog t) e = EOp ext (OLog t) e +d1op (OIDiv t) e = EOp ext (OIDiv t) e +d1op (OMod t) e = EOp ext (OMod t) e + +-- | Both primal and dual must be duplicable expressions +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) + | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) + +d2op :: SOp a t -> D2Op a t +d2op op = case op of + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d + OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) + ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t + ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) + OToFl64 -> Linear $ \_ -> ENil ext + ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) + OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) + OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TScal a) t) + -> D2Op (TScal a) t + d2opUnArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> ENil ext + STI64 -> Linear $ \_ -> ENil ext + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) + -> D2Op (TPair (TScal a) (TScal a)) t + d2opBinArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + + floatingD2 :: ScalIsFloating a ~ True + => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r + floatingD2 STF32 k = k + floatingD2 STF64 k = k + + integralD2 :: ScalIsIntegral a ~ True + => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r + integralD2 STI32 k = k + integralD2 STI64 k = k + +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) + +conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) +conv1Idx IZ = IZ +conv1Idx (IS i) = IS (conv1Idx i) + +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) +conv2Idx DTop i = case i of {} + +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" + + +----------------------------------- SPARSITY ----------------------------------- + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) + -> r) + -> r +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl + +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) + +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +--------------------------------- ACCUMULATORS --------------------------------- + +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing + +accumPromote :: forall dt env sto proxy r. + proxy dt + -> Descr env sto + -> (forall stoRepl envPro. + (Select env stoRepl "merge" ~ '[]) + => Descr env stoRepl + -- ^ A revised environment description that switches + -- arrays (used in the OccEnv) that are currently on + -- "merge" storage, to "accum" storage. + -> SList STy envPro + -- ^ New entries on top of the original dual environment, + -- that house the accumulators for the promoted arrays in + -- the original environment. + -> Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. + -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + -> VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. + -> (forall shbinds. + SList STy shbinds + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) + -- ^ A weakening that converts a computation in the + -- revised environment to one in the original environment + -- extended with some accumulators. + -> r) + -> r +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of + -- Accumulators are left as-is + SAccum -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + envpro + prosub + (SEYesR accrevsub) + (VarMap.sink1 accumMap) + (\shbinds -> + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + (SENo prosub) + accrevsub + accumMap' + wf + + -- Values with "merge" storage are promoted to an accumulator in envPro + _ -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + (t `SCons` envpro) + (SEYesR prosub) + (SENo accrevsub) + (let accumMap' = VarMap.sink1 accumMap + in case fromArrayValId vid of + Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' + Nothing -> accumMap') + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- Discrete values are left as-is, nothing to do + SDiscr -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + prosub + accrevsub + accumMap + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STLEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + +---------------------------- RETURN TRIPLE FROM CHAD --------------------------- + +data Ret env0 sto sd t = + forall shbinds tapebinds contribs. + Ret (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (Ex (Append shbinds (D1E env0)) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) + +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) + +data Rets env0 sto env list = + forall shbinds tapebinds. + Rets (Bindings Ex env shbinds) + (Subenv shbinds tapebinds) + (SList (RetPair env0 sto env shbinds tapebinds) list) +deriving instance Show (Rets env0 sto env list) + +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + +weakenRetPair :: SList STy shbinds -> env :> env' + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair +weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 + +weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list +weakenRets w (Rets binds tapesub list) = + let (binds', _) = weakenBindingsE w binds + in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) + +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. + Descr env0 sto + -> SList f b1 -> SList f b2 + -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) + | Refl <- lemAppendAssoc @b2 @b1 @env = + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) + +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat _ SNil = Rets BTop SETop SNil +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) + | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs + <- weakenRets (sinkWithBindings e0) (retConcat descr list) + , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) + , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) + = Rets (bconcat e0 binds) + (subenvConcat subtape subtape2) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) + sub + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) + subtape subtape2) + pairs)) + +freezeRet :: Descr env sto + -> Ret env sto (D2 t) t + -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) + in letBinds e0' $ + EPair ext + (weakenExpr wInsertD2Ac e1) + (ELet ext (weakenExpr (autoWeak library + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) + (#shbinds :++: #d :++: #d2ace :++: #tl)) + e2') $ + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) + + +---------------------------- THE CHAD TRANSFORMATION --------------------------- + +drev :: forall env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case + EVar _ t i -> + case conv2Idx des i of + Idx2Ac accI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + + Idx2Me tupI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) + + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + ELet _ (rhs :: Expr _ _ a) body + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in + Ret (bconcat (rhs0 `bpush` rhs1) body0') + (subenvConcat subtapeRHS subtapeBody) + (weakenExpr wbody0' body1) + subBoth + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #tl) + (#d :++: (#body :++: #rhs) :++: #tl)) + body2) $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ + plus_RHS_Body + (EVar ext (contribTupTy des subRHS) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) + + EPair _ a b + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EPair ext a1 b1) + subBoth + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) + + EFst _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> + Ret e0 + subtape + (EFst ext e1) + sub + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ + weakenExpr (WCopy WSink) e2) + + ESnd _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> + Ret e0 + subtape + (ESnd ext e1) + sub + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) + + EInl _ t2 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInl ext (d1 t2) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) + + EInr _ t1 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInr ext (d1 t1) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) + + ECase _ e (a :: Expr _ _ t) b + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + , let (bindids1, bindids2) = validSplitEither (extOf e) + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e + , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) + -> + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) + (SEYesR subtapeE) + (EFst ext (EVar ext tPrimal IZ)) + subOut + (elet + (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA + &. #prea0 prerebinds + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #ta0 :++: #tl) + (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) + a2) $ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB + &. #preb0 prerebinds + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #tb0 :++: #tl) + (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) + b2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ + plus_AB_E + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) + + EConst _ t val -> + Ret BTop + SETop + (EConst ext t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EOp _ op e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> + case d2op op of + Linear d2opfun -> + Ret e0 + subtape + (d1op op e1) + sub + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy WSink) e2)) + Nonlinear d2opfun -> + Ret (e0 `bpush` e1) + (SEYesR subtape) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + sub + (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + + ECustom _ _ tb _ srce pr du a b + -- allowed to ignore a2 because 'a' is the part of the input that is inactive + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @sd) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } + + EError _ t s -> + Ret BTop + SETop + (EError ext (d1 t) s) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EConstArr _ n t val -> + Ret BTop + SETop + (EConstArr ext n t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) + | SpArr @_ @sdElt sdElt <- sd + , let eltty = typeOf ef + , shty :: STy shty <- tTup (sreplicate ndim tIx) + , Refl <- indexTupD1Id ndim -> + drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds + `bpush` weakenExpr (wSinks (d1e provars)) + (EBuild ext ndim + (drevPrimal des she) + (letBinds e0 $ + EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in + ESnd ext $ + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap _ ef (earr :: Expr _ _ (TArr n a)) + | SpArr sdElt <- sd + , let STArr ndim t1 = typeOf earr + t2 = typeOf ef -> + drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> + case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds + ttape = typeOf ef1tape + library = #d1env (desD1E des) + &. #a0 (bindingsBinds ea0) + &. #atapebinds (subList (bindingsBinds ea0) easubtape) + &. #propr (d1e provars) + &. #x (d1 t1 `SCons` SNil) + &. #parr (STArr ndim (d1 t1) `SCons` SNil) + &. #tapearr (STArr ndim ttape `SCons` SNil) + &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) + &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) + &. #tape (ttape `SCons` SNil) + &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) + &. #d2acEnv (d2ace (select SAccum des)) + &. #pro (d2ace provars) + in + subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> + Ret (bconcat ea0 proPrimalBinds' + `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 + `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) + (letBinds ef0 $ + EPair ext ef1 ef1tape)) + (EVar ext (STArr ndim (d1 t1)) IZ) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) + (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) + subfa + (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout) $ + emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ + elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) + (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) + ef2) + (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) + (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ + plus_f_a + (ESnd ext (evar IZ)) + (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) + (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) + ea2))) + } + + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace provars) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in + weakenExpr (autoWeak library (#xy :++: #d1env) layout) + (letBinds ef0 $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + ef1 + (EPair ext + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) + ef1tape))) + (EVar ext (d1 eltty) (IS (IS IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ + plus_x₀a_f + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + + EUnit _ e + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> + Ret e0 + subtape + (EUnit ext e1) + sub + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EReplicate1Inner _ en e + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd + , let STArr ndim eltty = typeOf e -> + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } + + EIdx0 _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e + , STArr _ t <- typeOf e -> + Ret e0 + subtape + (EIdx0 ext e1) + sub + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" + {- + EIdx1 _ e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil + , STArr (SS n) eltty <- typeOf e -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) + (SEYesR (SENo subtape)) + (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) + (weakenExpr (WSink .> WSink) ei1)) + sub + (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + -} + + EIdx _ e ei + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + + EShape _ e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e + , Refl <- indexTupD1Id n -> + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) + (ENil ext) + + ESum1Inner _ e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , STArr (SS n) t <- typeOf e -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) + (SEYesR (SENo subtape)) + (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) + sub + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e + + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + + ENothing{} -> err_unsupported "ENothing" + EJust{} -> err_unsupported "EJust" + EMaybe{} -> err_unsupported "EMaybe" + ELNil{} -> err_unsupported "ELNil" + ELInl{} -> err_unsupported "ELInl" + ELInr{} -> err_unsupported "ELInr" + ELCase{} -> err_unsupported "ELCase" + + EWith{} -> err_accum + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" + + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + +data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) + +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. + RetScoped + (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds + (Subenv (Append shbinds '[D1 a]) tapebinds) + (Ex (Append shbinds (D1E (a : env0))) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + -- ^ merge contributions to the _enclosing_ merge environment + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) + -- ^ the merge contributions, plus the cotangent to the argument + -- (if there is any) +deriving instance Show (RetScoped env0 sto a s sd t) + +drevScoped :: forall a s env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd + -> Expr ValId (a : env) t + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of + SMerge + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + case sub of + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) + + SAccum + | chcSmartWith ?config + , Just (VIArr i _) <- argids + , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap + , Just Refl <- testEquality foundTy (STAccum (d2M argty)) + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ + let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in + ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ + weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + (EPair ext e2 (ENil ext)) + + | let accumMap' = case argids of + Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) + _ -> VarMap.sink1 accumMap + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) + e2 + + SDiscr + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds IZ) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) + = mapExt (const ext) e diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs new file mode 100644 index 0000000..6f25f11 --- /dev/null +++ b/src/CHAD/Drev/Accum.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD.Drev and CHAD.Drev.Top. +module CHAD.Drev.Accum where + +import CHAD.AST +import CHAD.Data +import CHAD.Drev.Types +import CHAD.AST.Env + + +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs index 4c287d7..5a90303 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/Drev/EnvDescr.hs @@ -7,18 +7,18 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.EnvDescr where +module CHAD.Drev.EnvDescr where import Data.Kind (Type) import Data.Some import GHC.TypeLits (Symbol) -import Analysis.Identity (ValId(..)) -import AST.Env -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Analysis.Identity (ValId(..)) +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types type Storage :: Symbol -> Type @@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of @@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des select s@SAccum (DPush des (_, _, SDiscr)) = select s des select s@SMerge (DPush des (_, _, SDiscr)) = select s des select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Top.hs b/src/CHAD/Drev/Top.hs index 261ddfe..65b4dee 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Drev/Top.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} @@ -8,18 +8,20 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Top where +module CHAD.Drev.Top where -import Analysis.Identity -import AST -import AST.SplitLets -import AST.Weaken.Auto -import CHAD -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import CHAD.Data.VarMap qualified as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types type family MergeEnv env where @@ -41,39 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) - else k (des `DPush` (t, Nothing, SMerge)) - -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) reassembleD2E :: Descr env sto + -> D1E env :> env' -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, _, SAccum)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) - (ESnd ext (EVar ext (typeOf e) IZ)))) - (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, _, SMerge)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) - (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) - (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) @@ -83,21 +71,22 @@ chad config env (term :: Ex env t) let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) tvar = STPair t1 (tTup (d2e (select SAccum descr))) in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (select SAccum descr) $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #acenv (d2ace (select SAccum descr)) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr VarMap.empty term')) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term') + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') where term' = identityAnalysis env (splitLets term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Drev/Types.hs index 974669d..367a974 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -1,10 +1,12 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Types where +module CHAD.Drev.Types where -import AST.Types -import Data +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data type family D1 t where @@ -18,11 +20,11 @@ type family D1 t where type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) + D2 (TPair a b) = TPair (D2 a) (D2 b) D2 (TEither a b) = TLEither (D2 a) (D2 b) D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TMaybe (TArr n (D2 t)) + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where @@ -60,11 +62,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env d2M :: STy t -> SMTy (D2 t) d2M STNil = SMTNil -d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STPair a b) = SMTPair (d2M a) (d2M b) d2M (STEither a b) = SMTLEither (d2M a) (d2M b) d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) d2M (STMaybe t) = SMTMaybe (d2M t) -d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STArr n t) = SMTArr n (d2M t) d2M (STScal t) = case t of STI32 -> SMTNil STI64 -> SMTNil @@ -95,6 +97,8 @@ data CHADConfig = CHADConfig chcCaseArrayAccum :: Bool , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool } deriving (Show) @@ -103,12 +107,14 @@ defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False , chcArgArrayAccum = False + , chcSmartWith = False } chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True - , chcArgArrayAccum = True } + , chcArgArrayAccum = True + , chcSmartWith = True } ------------------------------------ LEMMAS ------------------------------------ @@ -116,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs index 8476712..019119c 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -1,14 +1,14 @@ {-# LANGUAGE GADTs #-} -module CHAD.Types.ToTan where +module CHAD.Drev.Types.ToTan where import Data.Bifunctor (bimap) -import Array -import AST.Types -import CHAD.Types -import Data -import ForwardAD -import Interpreter.Rep +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) @@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of @@ -34,14 +32,12 @@ toTan typ primal der = case typ of (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Example.hs b/src/CHAD/Example.hs index d3f6d0d..34ff889 100644 --- a/src/Example.hs +++ b/src/CHAD/Example.hs @@ -7,26 +7,44 @@ {-# LANGUAGE TypeApplications #-} {-# OPTIONS -Wno-unused-imports #-} -module Example where - -import Array -import AST -import AST.Pretty -import AST.UnMonoid -import CHAD -import CHAD.Top -import ForwardAD -import Interpreter -import Language -import Simplify +module CHAD.Example where import Debug.Trace -import Example.Types + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Data +import CHAD.Drev +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.Interpreter +import CHAD.Language as L +import CHAD.Simplify -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) +pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +pipeline config term + | Dict <- styKnown (d2 (typeOf term)) = + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ chad' config knownEnv $ + simplifyFix $ term + +-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures +pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () +pipeline' config term + | Dict <- styKnown (d2 (typeOf term)) = + pprintExpr (pipeline config term) + + bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) @@ -162,8 +180,18 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - _ -> undefined + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) + +-- The build body uses free variables in a non-linear way, so their primal +-- values are required in the dual of the build. Thus, compositionally, they +-- are stored in the tape from each individual lambda invocation. This results +-- in n copies of y and z, where only one copy would have sufficed. +exUniformFree :: Ex '[R, I64] R +exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ + let_ #y (#x * 2) $ + let_ #z (#x * 3) $ + idx0 $ sum1i $ + build1 #n $ #i :-> #y * #z + toFloat_ #i diff --git a/src/Example/GMM.hs b/src/CHAD/Example/GMM.hs index 206e534..18641e8 100644 --- a/src/Example/GMM.hs +++ b/src/CHAD/Example/GMM.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeApplications #-} -module Example.GMM where +module CHAD.Example.GMM where -import Example.Types -import Language +import CHAD.Data (SList(..)) +import CHAD.Example.Types +import CHAD.Language diff --git a/src/Example/Types.hs b/src/CHAD/Example/Types.hs index d63159b..1e2f72d 100644 --- a/src/Example/Types.hs +++ b/src/CHAD/Example/Types.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} -module Example.Types where +module CHAD.Example.Types where -import AST -import Data +import CHAD.AST +import CHAD.Data type R = TScal TF64 diff --git a/src/ForwardAD.hs b/src/CHAD/ForwardAD.hs index b353def..0ae88ce 100644 --- a/src/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -4,24 +4,25 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD where +module CHAD.ForwardAD where import Data.Bifunctor (bimap) +import Data.Foldable (fold) import System.IO.Unsafe -- import Debug.Trace --- import AST.Pretty +-- import CHAD.AST.Pretty -import Array -import AST -import Compile -import Data -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.Compile +import CHAD.Data +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep --- | Tangent along a type (coincides with cotangent for these types) +-- | Tangent along a type (coincides with the cotangent, t'CHAD.Drev.Types.D2', for these types) type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) @@ -89,7 +90,7 @@ tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x -tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x +tanScalars (STArr _ t) x = fold $ arrayMap (tanScalars t) x tanScalars (STScal STI32) _ = [] tanScalars (STScal STI64) _ = [] tanScalars (STScal STF32) x = [realToFrac x] @@ -254,8 +255,10 @@ makeFwdADArtifactInterp env expr = in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) {-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) -makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) +makeFwdADArtifactCompile env expr = do + (fun, output) <- compile (dne env) (dfwdDN expr) + return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) diff --git a/src/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs index a6d5ec8..540ec2b 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -1,11 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -- I want to bring various type variables in scope using type annotations in @@ -14,14 +13,14 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD.DualNumbers ( +module CHAD.ForwardAD.DualNumbers ( dfwdDN, DN, DNS, DNE, dn, dne, ) where -import AST -import Data -import ForwardAD.DualNumbers.Types +import CHAD.AST +import CHAD.Data +import CHAD.ForwardAD.DualNumbers.Types dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) @@ -153,10 +152,11 @@ dfwdDN = \case (EConstArr ext n t x) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) + pairty = STPair (STScal t) (STScal t) in scalTyCase t (ELet ext (dfwdDN e) $ ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) @@ -168,6 +168,9 @@ dfwdDN = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) @@ -190,12 +193,17 @@ dfwdDN = \case EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" deriv_extremum :: ScalIsNumeric t ~ True => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) @@ -219,4 +227,4 @@ dfwdDN = \case (EFst ext (EVar ext tIxN (IS IZ))))))) (ESnd ext (EVar ext t2 (IS IZ))) (zeroScalarConst t)))) - (EMaximum1Inner ext (dfwdDN e)) + (extremum (dfwdDN e)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs index dcacf5f..5d5dd9e 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -1,10 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD.DualNumbers.Types where +module CHAD.ForwardAD.DualNumbers.Types where -import AST.Types -import Data +import CHAD.AST.Types +import CHAD.Data -- | Dual-numbers transformation diff --git a/src/Interpreter.hs b/src/CHAD/Interpreter.hs index 803a24a..6410b5b 100644 --- a/src/Interpreter.hs +++ b/src/CHAD/Interpreter.hs @@ -5,38 +5,40 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Interpreter ( +module CHAD.Interpreter ( interpret, interpretOpen, Value(..), ) where import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) +import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.Int (Int64) import Data.IORef +import Data.Tuple (swap) import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace -import Array -import AST -import AST.Pretty -import Data -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Interpreter.Rep newtype AcM s a = AcM { unAcM :: IO a } @@ -111,13 +113,16 @@ interpret'Rec env = \case EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) + EMap _ a b -> do + let STArr _ t = typeOf b + arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b EFold1Inner _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] ESum1Inner _ e -> do arr <- interpret' env e let STArr _ (STScal t) = typeOf e @@ -141,6 +146,50 @@ interpret'Rec env = \case sh `ShCons` n = arrayShape arr numericIsNum t $ return $ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr + EZip _ a b -> do + arr1 <- interpret' env a + arr2 <- interpret' env b + let sh = arrayShape arr1 + when (sh /= arrayShape arr2) $ + error "Interpreter: mismatched shapes in EZip" + return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) @@ -158,14 +207,17 @@ interpret'Rec env = \case initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do + EAccum _ t p e1 sp e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t p accum idx val + accumAddSparseD t p accum idx sp val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b @@ -216,6 +268,19 @@ zeroM typ zi = case typ of STF32 -> 0.0 STF64 -> 0.0 +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + addM :: SMTy t -> Rep t -> Rep t -> Rep t addM typ a b = case typ of SMTNil -> () @@ -239,7 +304,7 @@ addM typ a b = case typ of | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b -onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) @@ -256,15 +321,6 @@ withAccum t _ initval f = AcM $ do val <- readAc t accum return (out, val) -newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) -newAcZero typ zi = case typ of - SMTNil -> return () - SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi - SMTLEither{} -> newIORef Nothing - SMTMaybe _ -> newIORef Nothing - SMTArr _ t -> arrayMapM (newAcZero t) zi - SMTScal sty -> numericIsNum sty $ newIORef 0 - newAcDense :: SMTy a -> Rep a -> IO (RepAc a) newAcDense typ val = case typ of SMTNil -> return () @@ -274,26 +330,10 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (_, SAPHere) -> newAcDense typ val - - (SMTPair t1 t2, SAPFst prj') -> - (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) - (SMTPair t1 t2, SAPSnd prj') -> - (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val - - (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx - onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ziarr @@ -309,54 +349,67 @@ readAc typ val = case typ of SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val -accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () -accumAddDense typ ref val = case typ of - SMTNil -> return () - SMTPair t1 t2 -> do - accumAddDense t1 (fst ref) (fst val) - accumAddDense t2 (snd ref) (snd val) - SMTLEither{} -> - case val of - Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - SMTMaybe{} -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - SMTArr _ t1 -> - forM_ [0 .. arraySize ref - 1] $ \i -> - accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) - SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) - -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref val +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val - (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) (SMTArr n t1, SAPArrIdx prj') -> - let ((arrindex', ziarr), idx') = idx + let (arrindex', idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = arrayShape ziarr + arrsh = arrayShape ref linindex = toLinearIndex arrsh arrindex - in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) +-- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = -- Try modifying what's already in ref. The 'join' makes the snd @@ -405,3 +458,11 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y diff --git a/src/Interpreter/Accum.hs b/src/CHAD/Interpreter/Accum.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/Accum.hs +++ b/src/CHAD/Interpreter/Accum.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/AccumOld.hs b/src/CHAD/Interpreter/AccumOld.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/AccumOld.hs +++ b/src/CHAD/Interpreter/AccumOld.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs index 1682303..fadc6be 100644 --- a/src/Interpreter/Rep.hs +++ b/src/CHAD/Interpreter/Rep.hs @@ -3,7 +3,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module Interpreter.Rep where +module CHAD.Interpreter.Rep where import Control.DeepSeq import Data.Coerce (coerce) @@ -12,10 +12,10 @@ import Data.Foldable (toList) import Data.IORef import GHC.Exts (withDict) -import Array -import AST -import AST.Pretty -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.Data type family Rep t where diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs new file mode 100644 index 0000000..6621eef --- /dev/null +++ b/src/CHAD/Language.hs @@ -0,0 +1,423 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Language ( + -- * Named expressions + fromNamed, + NExpr, NFun, + + -- * Functions + lambda, + body, + inline, + (.$), + + -- * Basic language constructs + let_, + pair, fst_, snd_, nil, + inl, inr, case_, + nothing, just, maybe_, + + -- * Array operations + constArr_, + build1, build2, build, + map_, + fold1i, fold1i', + sum1i, + unit, + replicate1i, + maximum1i, minimum1i, + reshape, + fold1iD1, fold1iD1', + fold1iD2, + + -- * Scalar operations + -- | Note that 'NExpr' is also an instance of some numeric classes like 'Num' and 'Floating'. + const_, + idx0, + (!), + shape, + length_, + error_, + (.==), (.<), (CHAD.Language..>), (.<=), (.>=), + not_, and_, or_, + mod_, round_, toFloat_, idiv, + + -- * Control flow + if_, + + -- * Special operations + custom, + recompute, + with, accum, accumS, + oper, oper2, + + -- * Helper types + (:->)(..), + + -- * Reexports + TIx, + Lookup, + Ex, + Ty(..), + SNat(..), Nat(..), N0, N1, N2, N3, +) where + +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.Language.AST + + +-- | Helper type, used for e.g. 'case_' and 'build'. +data a :-> b = a :-> b + deriving (Show) +infixr 0 :-> + + +-- | See 'fromNamed' for a usage example. +body :: NExpr env t -> NFun env env t +body = NBody + +-- | See 'fromNamed' for a usage example. +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda = NLam + +-- | Inline a function here, with the given list of expressions as arguments. +-- While this is a normal 'SList', the @params@ list is reversed from the +-- natural argument order of the function; the '(.$)' helper operator serves to +-- "fix" the order. +-- +-- @ +-- let fun = 'lambda' \@(TScal TF64) #x $ 'lambda' \@(TScal TBool) #b $ 'body' $ if_ #b #x (#x + 1) +-- in 'inline' fun ('SNil' .$ 16 .$ 'const_' True) +-- @ +-- +-- Note that no 'const_' is needed for the @16@, because 'NExpr' implements +-- 'Num'. +inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t +inline = inlineNFun + +-- | Helper for constructing the argument list for 'inline'; +-- @(.$) = flip 'SCons'@. See 'inline'. +(.$) :: SList f list -> f a -> SList f (a : list) +(.$) = flip SCons + + +-- | The first 'Var' argument is the left-hand side of this let-binding. For example: +-- +-- @ +-- 'fromNamed' $ 'lambda' \@(TScal TI64) #a $ 'body' $ +-- 'let_' #x (#a + 1) $ +-- #x * #a +-- @ +-- +-- This produces an expression of type @'Ex' '[TScal TI64] (TScal TI64)@ that +-- corresponds to the Haskell code @\\a -> let x = a + 1 in x * a@. +let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t +let_ = NELet + +pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) +pair = NEPair + +fst_ :: NExpr env (TPair a b) -> NExpr env a +fst_ = NEFst + +snd_ :: NExpr env (TPair a b) -> NExpr env b +snd_ = NESnd + +nil :: NExpr env TNil +nil = NENil + +inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) +inl = NEInl knownTy + +inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) +inr = NEInr knownTy + +-- | A @case@ expression on @Either@s. For example, the following expression +-- will evaluate to 10 + 1 = 11: +-- +-- @ +-- 'case_' ('inl' 10) +-- (#x :-> #x + 1) +-- (#y :-> #y * 2) +-- @ +case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c +case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 + +nothing :: KnownTy a => NExpr env (TMaybe a) +nothing = NENothing knownTy + +just :: NExpr env a -> NExpr env (TMaybe a) +just = NEJust + +-- | Analogue of the 'Prelude.maybe' function in the Haskell Prelude: +-- +-- @ +-- 'maybe_' 2 (#x :-> #x * 3) (...) +-- @ +-- +-- will return 2 if @(...)@ is @Nothing@ and @x + 3@ if it is @Just x@. +maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b +maybe_ a (v :-> b) c = NEMaybe a v b c + +-- | To construct 'Array' values, see "CHAD.Array". +constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConstArr knownNat ty x + +-- | Special case of 'build' for 1-dimensional arrays. This produces the array +-- [0.0, 1.0, 2.0]: +-- +-- @ +-- 'build1' 3 (#i :-> 'toFloat_' #i) +-- @ +build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) +build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) + +-- | Special case of 'build' for 2-dimensional arrays. +build2 :: NExpr env TIx -> NExpr env TIx + -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) + -> NExpr env (TArr (S (S Z)) t) +build2 a1 a2 (v1 :-> v2 :-> b) = + NEBuild (SS (SS SZ)) + (pair (pair nil a1) a2) + #idx + (let_ v1 (snd_ (fst_ #idx)) $ + let_ v2 (NEDrop SZ (snd_ #idx)) $ + NEDrop (SS (SS SZ)) b) + +-- | General n-dimensional elementwise array constructor. A 3-dimensional index +-- looks like @((((), i1), i2), i3)@; other dimensionalities are analogous. The +-- innermost dimension (i.e. whose index variable varies the fastest in the +-- standard memory layout) is the right-most index, i.e. @i3@ in 3D example. To +-- create a 10-by-10 table of (row, column) pairs: +-- +-- @ +-- 'build' ('SS' ('SS' 'SZ')) ('pair' ('pair' 'nil' 10) 10) (#i :-> #j :-> 'pair' #i #j) +-- @ +build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) +build n a (v :-> b) = NEBuild n a v b + +map_ :: forall n a b env name. (KnownNat n, KnownTy a) + => (Var name a :-> NExpr ('(name, a) : env) b) + -> NExpr env (TArr n a) -> NExpr env (TArr n b) +map_ (v :-> a) b = NEMap v a b + +-- | Fold over the innermost dimension of an array, thus reducing its dimensionality by one. +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | The underlying AST constructor for a fold takes a function with /one/ +-- argument: a pair of inputs. 'fold1i'' directly returns this AST constructor +-- in case it is helpful for testing. The 'fold1i' function is a convenience +-- wrapper around 'fold1i''. +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 + +sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +sum1i e = NESum1Inner e + +unit :: NExpr env t -> NExpr env (TArr Z t) +unit = NEUnit + +replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) +replicate1i n a = NEReplicate1Inner n a + +maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +maximum1i e = NEMaximum1Inner e + +minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +minimum1i e = NEMinimum1Inner e + +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + +-- | 'fold1iD1'' with a curried combination function. +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | Primal of a fold. Not supported in the input program for reverse differentiation. +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 + +-- | Reverse pass of a fold. Not supported in the input program for reverse differentiation. +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + +const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) +const_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty x + +idx0 :: NExpr env (TArr Z t) -> NExpr env t +idx0 = NEIdx0 + +-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +-- (.!) = NEIdx1 +-- infixl 9 .! + +-- | Index an array. Note that the index is a tuple, just like the argument to +-- the function in 'build'. To index a 2-dimensional array @a@ at row @i@ and +-- column @j@, write @a '!' 'pair' ('pair' 'nil' i) j@. +(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx +infixl 9 ! + +shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) +shape = NEShape + +-- | Convenience special case of 'shape' for single-dimensional arrays. +length_ :: NExpr env (TArr N1 t) -> NExpr env TIx +length_ e = snd_ (shape e) + +oper :: SOp a t -> NExpr env a -> NExpr env t +oper = NEOp + +oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t +oper2 op a b = NEOp op (pair a b) + +error_ :: KnownTy t => String -> NExpr env t +error_ s = NEError knownTy s + +-- | Specify a custom reverse derivative for a subexpression. Morally, the type +-- of this combinator should be read as follows: +-- +-- @ +-- custom :: (a -> b -> t) -- normal semantics +-- -> (D1 a -> D1 b -> (D1 t, tape)) -- forward pass +-- -> (tape -> D2 t -> D2 b) -- reverse pass +-- -> a -> b -- arguments +-- -> t -- result +-- @ +-- +-- In normal evaluation, or when forward-differentiating, the first argument is +-- taken and the second and third are ignored. When reverse-differentiating +-- using CHAD, however, the /first/ argument is ignored and the second and +-- third arguments are respectively put in the forward and the reverse passes +-- of the derivative program. The @tape@ value may be used to remember primals +-- for the reverse pass. +-- +-- This combinator allows for "inactive" and "active" inputs to the operation; +-- derivatives to the "inactive" input are not propagated. The active input +-- (whose derivatives /are/ propagated) has type @b@; the inactive input has +-- type @a@. +-- +-- No accumulators are allowed inside @a@, @b@ and @tape@. +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + +-- | Semantically the identity, but when reverse differentiating using CHAD, +-- the contained expression is recomputed in the reverse pass. This is a +-- light-weight form of checkpointing, with the goal of reducing the number +-- primal values being stored and thus reducing memory use and memory traffic. +-- +-- Note that free variables of the contained expression do still need to be +-- stored, as we do need to be able to recompute the expression in the reverse +-- pass. +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + +-- | Introduce an accumulator. The initial value is not allowed to be sparse! +-- See 'CHAD.AST.EWith'. Not supported in the input program for reverse +-- differentiation. +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b + +-- | Accumulate to an accumulator. Not supported in the input program for +-- reverse differentiation. +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +-- | Accumulate to an accumulator with additional sparsity. Not supported in +-- the input program for reverse differentiation. +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c + + +(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .== b = oper (OEq knownScalTy) (pair a b) +infix 4 .== + +(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .< b = oper (OLt knownScalTy) (pair a b) +infix 4 .< + +(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>) = flip (.<) +infix 4 .> + +(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .<= b = oper (OLe knownScalTy) (pair a b) +infix 4 .<= + +(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>=) = flip (.<=) +infix 4 .>= + +not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) +not_ = oper ONot + +and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +and_ = oper2 OAnd +infixr 3 `and_` + +or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +or_ = oper2 OOr +infixr 2 `or_` + +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + +-- | The first alternative is the True case; the second is the False case. +if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) + +round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) +round_ = oper ORound64 + +toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) +toFloat_ = oper OToFl64 + +idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) +idiv = oper2 (OIDiv knownScalTy) +infixl 7 `idiv` diff --git a/src/Language/AST.hs b/src/CHAD/Language/AST.hs index 7e074df..502a2b3 100644 --- a/src/Language/AST.hs +++ b/src/CHAD/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -12,19 +14,22 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module Language.AST where +module CHAD.Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) -import Array -import AST -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +-- | A named expression: variables have names, not De Bruijn indices. +-- Otherwise essentially identical to 'Expr'. type NExpr :: [(Symbol, Ty)] -> Ty -> Type data NExpr env t where -- lambda calculus @@ -49,12 +54,23 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -76,7 +92,7 @@ data NExpr env t where -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a @@ -85,11 +101,23 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +-- | Look up the type of a name in a named environment. +type Lookup name env = Lookup1 (name == "_") name env +-- | This curious stack of type families is used instead of normal pattern +-- matching so the decidable boolean predicate "==" is used. This means that +-- introducing evidence of @(name1 == name2) ~ False@ may allow a certain +-- lookup to reduce even if the names in question are not statically known. +-- This flexibility is used with e.g. 'assertSymbolDistinct' in +-- 'CHAD.Language.fold1i'. +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -141,10 +169,20 @@ data NEnv env where NTop :: NEnv '[] NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) --- | First (outermost) parameter on the outside, on the left. --- * env: environment of this function (grows as you go deeper inside lambdas) --- * env': environment of the body of the function --- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list +-- | A named /function/. These can be used in only two ways: they can be +-- converted to an unnamed 'Expr' using 'fromNamed', and they can be inlined +-- using 'CHAD.Language.inline'. +-- +-- * @env@: environment of this function (smaller than @env'@; grows as you descend under lambdas) +-- * @env'@: environment of the body of the function +-- +-- For example, a function @(\\(x :: a) (y :: b) -> _ :: c)@ with two free +-- variables, @u :: t1@ and @v :: t2@, would be represented with a value of the +-- following type: +-- +-- @ +-- NFun '['("v", t2), '("u", t1)] '['("y", b), '("x", a), '("v", t2), '("u", t1)] c +-- @ data NFun env env' t where NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t NBody :: NExpr env' t -> NFun env' env' t @@ -160,6 +198,41 @@ envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t inlineNFun fun args = NEUnnamed (fromNamed fun) args +-- | Convert a named function to an unnamed expression with free variables, +-- ready for consumption by the rest of this library. The function must be +-- closed (meaning that the function as a whole cannot have free variables), +-- and the arguments of the function are realised as free variables of the +-- resulting expression. Typical usage looks as follows: +-- +-- @ +-- {-# LANGUAGE OverloadedLabels #-} +-- import CHAD.Language +-- 'fromNamed' $ 'CHAD.Language.lambda' \@(TScal TF64) #x $ 'CHAD.Language.lambda' \@(TScal TI64) #i $ 'CHAD.Language.body' $ #x + 'CHAD.Language.toFloat_' #i +-- :: 'Ex' '[TScal TI64, TScal TF64] (TScal TF64) +-- @ +-- +-- The rest of the library generally considers expressions with free variables +-- to stand in for "functions", by considering the free variables as the +-- function's inputs. +-- +-- Note that while environments normally grow to the right (e.g. in type theory +-- notation), as they as type-level lists here, they grow to the /left/. This +-- is why the second (innermost) argument of the example, @i@, ends up at the +-- head of the environment of the constructed expression. +-- +-- __Type applications__: The type applications to 'CHAD.Language.lambda' above +-- are good practice, but not always necessary; if GHC can infer the type of +-- the argument from the body of the expression, the type application is +-- unnecessary. +-- +-- __Variables__: The major element of syntactic sugar in this module is using +-- OverloadedLabels for variable names. Variables are represented in 'NExpr' +-- (and thus 'NFun') using the 'Var' type; you should never have to manually +-- construct a 'Var'. Instead, 'Var' implements 'IsLabel' and as such can be +-- produced with the syntax @#name@, where "name" is the name of the variable. +-- This syntax produces a polymorphic variable reference whose (embedded) type +-- is left to GHC's type inference engine using a 'KnownTy' constraint. See +-- also 'CHAD.Language.let_'. fromNamed :: NFun '[] env t -> Ex (UnName env) t fromNamed = fromNamedFun NTop @@ -198,12 +271,17 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEMap n a b -> EMap ext (lambda val n a) (go b) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) NEMaximum1Inner e -> EMaximum1Inner ext (go e) NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) + + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) @@ -221,7 +299,7 @@ fromNamedExpr val = \case NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s @@ -260,3 +338,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/Lemmas.hs b/src/CHAD/Lemmas.hs index 31a43ed..55ef042 100644 --- a/src/Lemmas.hs +++ b/src/CHAD/Lemmas.hs @@ -4,7 +4,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE AllowAmbiguousTypes #-} -module Lemmas (module Lemmas, (:~:)(Refl)) where +module CHAD.Lemmas (module CHAD.Lemmas, (:~:)(Refl)) where import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) diff --git a/src/Simplify.hs b/src/CHAD/Simplify.hs index d963b7e..ea253d6 100644 --- a/src/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE QuasiQuotes #-} @@ -10,7 +12,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Simplify ( +module CHAD.Simplify ( simplifyN, simplifyFix, SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, ) where @@ -19,15 +21,16 @@ import Control.Monad (ap) import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) import Debug.Trace -import AST -import AST.Count -import AST.Pretty -import Data -import Simplify.TH +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.AST.UnMonoid (acPrjCompose) +import CHAD.Data +import CHAD.Simplify.TH data SimplifyConfig = SimplifyConfig @@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id) smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) smReconstruct core = SM (\ctx -> (Any False, ctx core)) -tellActed :: SM tenv tt env t () -tellActed = SM (\_ -> (Any True, ())) +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) -- more convenient in practice -acted :: SM tenv tt env t a -> SM tenv tt env t a +acted :: ActedMonad m => m a -> m a acted m = tellActed >> m within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) -acted' :: (Any, a) -> (Any, a) -acted' (_, x) = (Any True, x) - -liftActed :: (Any, a) -> SM tenv tt env t a -liftActed pair = SM $ \_ -> pair - simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) simplify' expr | scLogging ?config = do @@ -167,15 +176,30 @@ simplify'Rec = \case ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 (ELet _ rhs body) acc -> + EAccum _ t p e1 sp (ELet _ rhs body) acc -> acted $ simplify' $ ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) -- let () = e in () ~> e ELet _ e1 (ENil _) | STNil <- typeOf e1 -> acted $ simplify' e1 + -- map (\_ -> x) e ~> build (shape e) (\_ -> x) + EMap _ e1 e2 + | Occ Zero Zero <- occCount IZ e1 + , STArr n _ <- typeOf e2 -> + acted $ simplify' $ + EBuild ext n (EShape ext e2) $ + subst (\_ t' -> \case IZ -> error "Unused variable was used" + IS i -> EVar ext t' (IS i)) + e1 + + -- vertical fusion + EMap _ e1 (EMap _ e2 e3) -> + acted $ simplify' $ + EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 + -- projection down-commuting EFst _ (ECase _ e1 e2 e3) -> acted $ simplify' $ @@ -183,10 +207,23 @@ simplify'Rec = \case ESnd _ (ECase _ e1 e2 e3) -> acted $ simplify' $ ECase ext e1 (ESnd ext e2) (ESnd ext e3) + EFst _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (EFst ext e1) (EFst ext e2) e3 + ESnd _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 -- TODO: more array indexing - EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) - EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 + EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 + EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) + EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 + + -- TODO: more array shape + EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 + EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) + EShape _ (EReplicate1Inner _ en earr) -> acted $ simplify' (EPair ext (EShape ext earr) en) -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) @@ -216,23 +253,40 @@ simplify'Rec = \case acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules - EAccum _ t p e1 e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') (acted $ return (ENil ext)) - (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\e -> acted $ return e) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -272,12 +326,17 @@ simplify'Rec = \case ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] EConstArr _ n t v -> pure $ EConstArr ext n t v EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EMap _ a b -> [simprec| EMap ext *a *b |] EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] EUnit _ e -> [simprec| EUnit ext *e |] EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EReshape _ n a b -> [simprec| EReshape ext n *a *b |] + EZip _ a b -> [simprec| EZip ext *a *b |] + EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] + EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> [simprec| EIdx0 ext *e |] EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] @@ -296,20 +355,13 @@ simplify'Rec = \case e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) pure (EWith ext t e1' e2') - EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b + -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] + -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - EUnit _ e -> cheapExpr e - _ -> False - -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. hasAdds :: Expr x env t -> Bool @@ -332,12 +384,17 @@ hasAdds = \case ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b + EMap _ a b -> hasAdds a || hasAdds b EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e + EReshape _ _ a b -> hasAdds a || hasAdds b + EZip _ a b -> hasAdds a || hasAdds b + EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e @@ -347,8 +404,9 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b ERecompute _ e -> hasAdds e - EAccum _ _ _ _ _ _ -> True + EAccum _ _ _ _ _ _ _ -> True EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -367,51 +425,161 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm env p a b where - OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' -simplifyOneHotTerm :: OneHotTerm env p a b - -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) - -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) - -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do - val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 - case val1' of - EZero{} -> kzero - EOneHot _ t2 prj2 idx2 val2 - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do - tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k - _ -> case prj1 of - SAPHere -> ktriv val1 - _ -> k (OneHotTerm t1 prj1 idx1 val1) +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' (a', b') -> return $ EPair ext a' b' recogniseMonoid typ@(SMTLEither t1 t2) expr = case expr of - ELNil{} -> acted' $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e _ -> return expr recogniseMonoid typ@(SMTMaybe t1) expr = case expr of - ENothing{} -> acted' $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e _ -> return expr recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted' $ do + acted $ do e' <- recogniseMonoid t e return $ ELet ext e' $ @@ -420,59 +588,33 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = (ENil ext)) (EVar ext (fromSMTy t) IZ) recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (SMTPair a _, SAPFst prj1') -> - concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SMTPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - - (SMTLEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (SMTLEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (SMTMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) - (SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) where -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) go t SAPHere _ e = makeZeroInfo t e go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) go SMTLEither{} _ _ _ = ENil ext go SMTMaybe{} _ _ _ = ENil ext go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext diff --git a/src/Simplify/TH.hs b/src/CHAD/Simplify/TH.hs index 2e0076a..4af5394 100644 --- a/src/Simplify/TH.hs +++ b/src/CHAD/Simplify/TH.hs @@ -1,9 +1,9 @@ {-# LANGUAGE TemplateHaskellQuotes #-} -module Simplify.TH (simprec) where +module CHAD.Simplify.TH (simprec) where import Data.Bifunctor (first) import Data.Char -import Data.List (foldl1') +import Data.List (foldl', foldl1') import Language.Haskell.TH import Language.Haskell.TH.Quote import Text.ParserCombinators.ReadP diff --git a/src/Util/IdGen.hs b/src/CHAD/Util/IdGen.hs index 3f6611d..d4fd945 100644 --- a/src/Util/IdGen.hs +++ b/src/CHAD/Util/IdGen.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -module Util.IdGen where +module CHAD.Util.IdGen where import Control.Monad.Fix import Control.Monad.Trans.State.Strict diff --git a/src/Language.hs b/src/Language.hs deleted file mode 100644 index 7a780a0..0000000 --- a/src/Language.hs +++ /dev/null @@ -1,229 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitForAll #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} -module Language ( - fromNamed, - NExpr, - Ex, - module Language, - module AST.Types, - module Data, - Lookup, -) where - -import Array -import AST -import AST.Types -import CHAD.Types -import Data -import Language.AST - - -data a :-> b = a :-> b - deriving (Show) -infixr 0 :-> - - -body :: NExpr env t -> NFun env env t -body = NBody - -lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t -lambda = NLam - -inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t -inline = inlineNFun - --- To be used to construct the argument list for 'inline'. --- --- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y --- > in inline fun (SNil .$ 16 .$ 26) -(.$) :: SList f list -> f a -> SList f (a : list) -(.$) = flip SCons - - -let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t -let_ = NELet - -pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) -pair = NEPair - -fst_ :: NExpr env (TPair a b) -> NExpr env a -fst_ = NEFst - -snd_ :: NExpr env (TPair a b) -> NExpr env b -snd_ = NESnd - -nil :: NExpr env TNil -nil = NENil - -inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) -inl = NEInl knownTy - -inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) -inr = NEInr knownTy - -case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c -case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 - -nothing :: KnownTy a => NExpr env (TMaybe a) -nothing = NENothing knownTy - -just :: NExpr env a -> NExpr env (TMaybe a) -just = NEJust - -maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b -maybe_ a (v :-> b) c = NEMaybe a v b c - -constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) -constArr_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConstArr knownNat ty x - -build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) -build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) - -build2 :: NExpr env TIx -> NExpr env TIx - -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) - -> NExpr env (TArr (S (S Z)) t) -build2 a1 a2 (v1 :-> v2 :-> b) = - NEBuild (SS (SS SZ)) - (pair (pair nil a1) a2) - #idx - (let_ v1 (snd_ (fst_ #idx)) $ - let_ v2 (NEDrop SZ (snd_ #idx)) $ - NEDrop (SS (SS SZ)) b) - -build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) -build n a (v :-> b) = NEBuild n a v b - -map_ :: forall n a b env name. (KnownNat n, KnownTy a) - => (Var name a :-> NExpr ('(name, a) : env) b) - -> NExpr env (TArr n a) -> NExpr env (TArr n b) -map_ (v :-> a) b - | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) = - let_ #arg b $ - build knownNat (shape #arg) $ #i :-> - let_ v (#arg ! #i) $ - NEDrop (SS SZ) (NEDrop (SS SZ) a) - -fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 - -sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -sum1i e = NESum1Inner e - -unit :: NExpr env t -> NExpr env (TArr Z t) -unit = NEUnit - -replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) -replicate1i n a = NEReplicate1Inner n a - -maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -maximum1i e = NEMaximum1Inner e - -minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -minimum1i e = NEMinimum1Inner e - -const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) -const_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty x - -idx0 :: NExpr env (TArr Z t) -> NExpr env t -idx0 = NEIdx0 - --- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) --- (.!) = NEIdx1 --- infixl 9 .! - -(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t -(!) = NEIdx -infixl 9 ! - -shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -shape = NEShape - -length_ :: NExpr env (TArr N1 t) -> NExpr env TIx -length_ e = snd_ (shape e) - -oper :: SOp a t -> NExpr env a -> NExpr env t -oper = NEOp - -oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t -oper2 op a b = NEOp op (pair a b) - -error_ :: KnownTy t => String -> NExpr env t -error_ s = NEError knownTy s - -custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) - -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) - -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) - -> NExpr env a -> NExpr env b - -> NExpr env t -custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = - NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 - -recompute :: NExpr env a -> NExpr env a -recompute = NERecompute - -with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) -with a (n :-> b) = NEWith (knownMTy @t) a n b - -accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownMTy p a b c - - -(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .== b = oper (OEq knownScalTy) (pair a b) -infix 4 .== - -(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .< b = oper (OLt knownScalTy) (pair a b) -infix 4 .< - -(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>) = flip (.<) -infix 4 .> - -(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .<= b = oper (OLe knownScalTy) (pair a b) -infix 4 .<= - -(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>=) = flip (.<=) -infix 4 .>= - -not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -not_ = oper ONot - -and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -and_ = oper2 OAnd -infixr 3 `and_` - -or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -or_ = oper2 OOr -infixr 2 `or_` - -mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) -mod_ = oper2 (OMod knownScalTy) -infixl 7 `mod_` - --- | The first alternative is the True case; the second is the False case. -if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t -if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) - -round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) -round_ = oper ORound64 - -toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) -toFloat_ = oper OToFl64 - -idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) -idiv = oper2 (OIDiv knownScalTy) -infixl 7 `idiv` diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index e0dc4b3..5ca0f38 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -1,63 +1,109 @@ +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module Test.Framework ( TestTree, testGroup, - testGroupCollapse, + groupSetCollapse, testProperty, - withResource, - withResource', runTests, defaultMain, Options(..), + -- * Resources + withResource, + withResource', + TestCtx, + outputWarningText, + -- * Compatibility TestName, ) where -import Control.Monad (forM, when) -import Control.Monad.Trans.State.Strict +import Control.Concurrent (setNumCapabilities, forkIO, forkOn, killThread) +import Control.Concurrent.MVar +import Control.Concurrent.STM +import Control.Exception (SomeException, throw, try, throwIO) +import Control.Monad (forM, when, forM_) import Control.Monad.IO.Class +import Data.IORef import Data.List (isInfixOf, intercalate) -import Data.Maybe (isJust, mapMaybe, fromJust) +import Data.Maybe (mapMaybe, fromJust) +import Data.Monoid (All(..), Any(..), Sum(..)) +import Data.PQueue.Prio.Min qualified as PQ import Data.String (fromString) import Data.Time.Clock -import System.Environment +import GHC.Conc (getNumProcessors) +import GHC.Generics (Generic, Generically(..)) +import System.Console.ANSI qualified as ANSI +import System.Console.Concurrent (outputConcurrent) +import System.Console.Regions +import System.Environment (getArgs, getProgName) import System.Exit import System.IO (hFlush, hPutStrLn, stdout, stderr, hIsTerminalDevice) import Text.Read (readMaybe) -import qualified Hedgehog as H -import qualified Hedgehog.Internal.Config as H -import qualified Hedgehog.Internal.Property as H -import qualified Hedgehog.Internal.Report as H -import qualified Hedgehog.Internal.Runner as H -import qualified Hedgehog.Internal.Seed as H.Seed +import Hedgehog qualified as H +import Hedgehog.Internal.Config qualified as H +import Hedgehog.Internal.Property qualified as H +import Hedgehog.Internal.Report qualified as H +import Hedgehog.Internal.Runner qualified as H +import Hedgehog.Internal.Seed qualified as H.Seed + +-- TODO: with GHC 9.12 we have tryWithContext and rethrowIO, which is better for rethrowing exceptions + + +type TestName = String data TestTree - = Group Bool String [TestTree] - | forall a. Resource (IO a) (a -> IO ()) (a -> TestTree) + = Group GroupOpts String [TestTree] + | forall a. Resource String ((?testCtx :: TestCtx) => IO a) ((?testCtx :: TestCtx) => a -> IO ()) (a -> TestTree) + -- ^ Name is not specified by user, but inherited from the tree below | HP String H.Property -type TestName = String +data TestCtx = TestCtx + { tctxOutput :: String -> IO () } + +outputWarningText :: (?testCtx :: TestCtx) => String -> IO () +outputWarningText = tctxOutput ?testCtx + +-- Not exported because a Resource is not supposed to have a name in the first place +treeName :: TestTree -> String +treeName (Group _ name _) = name +treeName (Resource name _ _ _) = name +treeName (HP name _) = name + +data GroupOpts = GroupOpts + { goCollapse :: Bool } + deriving (Show) + +defaultGroupOpts :: GroupOpts +defaultGroupOpts = GroupOpts False testGroup :: String -> [TestTree] -> TestTree -testGroup = Group False +testGroup = Group defaultGroupOpts -testGroupCollapse :: String -> [TestTree] -> TestTree -testGroupCollapse = Group True +groupSetCollapse :: TestTree -> TestTree +groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name trees +groupSetCollapse _ = error "groupSetCollapse: not called on a Group" --- | The @a -> TestTree@ function must use the @a@ only inside properties: when --- not actually running properties, it will be passed 'undefined'. -withResource :: IO a -> (a -> IO ()) -> (a -> TestTree) -> TestTree -withResource = Resource +-- | The @a -> TestTree@ function must use the @a@ only inside properties: the +-- function will be passed 'undefined' when exploring the test tree (without +-- running properties). +withResource :: ((?testCtx :: TestCtx) => IO a) -> ((?testCtx :: TestCtx) => a -> IO ()) -> (a -> TestTree) -> TestTree +withResource make cleanup fun = Resource (treeName (fun undefined)) make cleanup fun -- | Same caveats as 'withResource'. -withResource' :: IO a -> (a -> TestTree) -> TestTree +withResource' :: ((?testCtx :: TestCtx) => IO a) -> (a -> TestTree) -> TestTree withResource' make fun = withResource make (\_ -> return ()) fun testProperty :: String -> H.Property -> TestTree @@ -66,27 +112,35 @@ testProperty = HP filterTree :: Options -> TestTree -> Maybe TestTree filterTree (Options { optsPattern = pat }) = go [] where - go path (Group collapse name trees) = + go path (Group opts name trees) = case mapMaybe (go (name:path)) trees of [] -> Nothing - trees' -> Just (Group collapse name trees') - go path (Resource make free fun) = + trees' -> Just (Group opts name trees') + go path (Resource inhname make free fun) = case go path (fun undefined) of Nothing -> Nothing - Just _ -> Just $ Resource make free (fromJust . go path . fun) + Just _ -> Just $ Resource inhname make free (fromJust . go path . fun) go path hp@(HP name _) | pat `isInfixOf` renderPath (name:path) = Just hp | otherwise = Nothing renderPath comps = "^" ++ intercalate "/" (reverse comps) ++ "$" +treeNumTests :: TestTree -> Int +treeNumTests (Group _ _ ts) = sum (map treeNumTests ts) +treeNumTests (Resource _ _ _ fun) = treeNumTests (fun undefined) +treeNumTests HP{} = 1 + computeMaxLen :: TestTree -> Int computeMaxLen = go 0 where go :: Int -> TestTree -> Int - go indent (Group True name trees) = maximum (2*indent + length name : map (go (indent+1)) trees) - go indent (Group False _ trees) = maximum (0 : map (go (indent+1)) trees) - go indent (Resource _ _ fun) = go indent (fun undefined) + go indent (Group opts name trees) + -- If we collapse, the name of the group gets prefixed before the final status message after collapsing. + | goCollapse opts = maximum (2*indent + length name : map (go (indent+1)) trees) + -- If we don't collapse, the group name does get printed but without any status message, so it doesn't need to get accounted for in maxlen. + | otherwise = maximum (0 : map (go (indent+1)) trees) + go indent (Resource _ _ _ fun) = go indent (fun undefined) go indent (HP name _) = 2 * indent + length name data Stats = Stats @@ -97,22 +151,21 @@ data Stats = Stats initStats :: Stats initStats = Stats 0 0 -newtype M a = M (StateT Stats IO a) - deriving newtype (Functor, Applicative, Monad, MonadIO) - -modifyStats :: (Stats -> Stats) -> M () -modifyStats f = M (modify f) +modifyStats :: (?stats :: IORef Stats) => (Stats -> Stats) -> IO () +modifyStats f = atomicModifyIORef' ?stats (\s -> (f s, ())) data Options = Options { optsPattern :: String , optsHelp :: Bool , optsHedgehogReplay :: Maybe (H.Skip, H.Seed) , optsHedgehogShrinks :: Maybe Int + , optsParallel :: Bool + , optsVerbose :: Bool } deriving (Show) defaultOptions :: Options -defaultOptions = Options "" False Nothing Nothing +defaultOptions = Options "" False Nothing Nothing False False parseOptions :: [String] -> Options -> Either String Options parseOptions [] opts = pure opts @@ -134,6 +187,8 @@ parseOptions ("--hedgehog-shrinks":arg:args) opts = case readMaybe arg of Just n -> parseOptions args opts { optsHedgehogShrinks = Just n } Nothing -> Left "Invalid argument to '--hedgehog-shrinks'" +parseOptions ("--parallel":args) opts = parseOptions args opts { optsParallel = True } +parseOptions ("--verbose":args) opts = parseOptions args opts { optsVerbose = True } parseOptions (arg:_) _ = Left $ "Unrecognised argument: '" ++ arg ++ "'" printUsage :: IO () @@ -147,7 +202,12 @@ printUsage = do ," test looks like: '^group1/group2/testname$'." ," --hedgehog-replay '{skip} {seed}'" ," Skip to a particular generated Hedgehog test. Should be used" - ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set."] + ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set." + ," --hedgehog-shrinks NUM" + ," Limit the number of shrinking steps." + ," --parallel Run tests in parallel." + ," --verbose Currently only has an effect with --parallel. Also shows OK" + ," and timing for test groups, not only individual tests."] defaultMain :: TestTree -> IO () defaultMain tree = do @@ -165,58 +225,210 @@ runTests options = \tree' -> return (ExitFailure 1) Just tree -> do isterm <- hIsTerminalDevice stdout - let M m = let ?maxlen = computeMaxLen tree - ?istty = isterm - in go 0 id tree starttm <- getCurrentTime - (success, stats) <- runStateT m initStats + statsRef <- newIORef initStats + success <- + let ?stats = statsRef + ?options = options + ?maxlen = computeMaxLen tree + ?istty = isterm + in if optsParallel options + then do nproc <- getNumProcessors + setNumCapabilities nproc + displayConsoleRegions $ + withWorkerPool nproc $ \pool -> do + let ?pool = pool + successVar <- newEmptyMVar + runTreePar Nothing [] [] tree successVar + readMVar successVar + else getAll . seqresAllSuccess <$> runTreeSeq 0 [] tree + stats <- readIORef statsRef endtm <- getCurrentTime - let ?istty = isterm in printStats stats (diffUTCTime endtm starttm) - return (if isJust success then ExitSuccess else ExitFailure 1) + let ?istty = isterm in printStats (treeNumTests tree) stats (diffUTCTime endtm starttm) + return (if success then ExitSuccess else ExitFailure 1) + +-- | Returns when all jobs in this tree have been scheduled. When all jobs are +-- done, the outvar is filled with whether all tests in this tree were +-- successful. +-- The mparregion is the parent region to take over, if any. Having a parent +-- region to take over implies that we are currently executing in a worker and +-- can hence run blocking user code directly in the current thread. +runTreePar :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> [Int] -> [String] -> TestTree -> MVar Bool -> IO () +-- TODO: handle collapse somehow? +runTreePar mparregion revidxlist revpath (Group _ name trees) outvar = do + let path = intercalate "/" (reverse (name : revpath)) + -- outputConcurrent $ "! " ++ path ++ ": Started\n" + + -- If we have exactly one child and we're currently running in a worker, we can continue doing so + mparregion2 <- case trees of + [_] -> return mparregion + _ -> do -- If not, we have to close the region (and implicitly relinquish the worker job) + forM_ mparregion closeConsoleRegion + return Nothing + + starttm <- getCurrentTime + suboutvars <- forM (zip trees [0..]) $ \(tree, idx) -> do + suboutvar <- newEmptyMVar + runTreePar mparregion2 (idx : revidxlist) (name : revpath) tree suboutvar + return suboutvar + + -- outputConcurrent $ "! " ++ path ++ ": Scheduled all\n" + + -- If we took over the parent region then this readMVar will resolve + -- immediately and the forkIO would be unnecessary. Meh. + _ <- forkIO $ do + success <- and <$> traverse readMVar suboutvars + endtm <- getCurrentTime + -- outputConcurrent $ "! " ++ path ++ ": Done\n" + if success && optsVerbose ?options + then let text = path ++ ": " ++ ansiGreen ++ "OK" ++ ansiReset ++ " " ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + in outputConcurrent (text ++ "\n") + else return () + putMVar outvar success + + return () + +runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runResource topmparregion 1 toptree topoutvar where - -- If all tests are successful, returns the number of output lines produced - go :: (?maxlen :: Int, ?istty :: Bool) => Int -> (String -> String) -> TestTree -> M (Maybe Int) - go indent path (Group collapse name trees) = do - liftIO $ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout - starttm <- liftIO getCurrentTime - mlns <- fmap (fmap sum . sequence) . forM trees $ - go (indent + 1) (path . (name++) . ('/':)) - endtm <- liftIO getCurrentTime - case mlns of - Just lns | collapse, ?istty -> do - let thislen = 2*indent + length name - liftIO $ putStrLn $ concat (replicate (lns+1) "\x1B[A\x1B[2K") ++ "\x1B[G" ++ - replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ - "\x1B[32mOK\x1B[0m" ++ - prettyDuration False (realToFrac (diffUTCTime endtm starttm)) - return (Just 1) - _ -> return mlns - go indent path (Resource make cleanup fun) = do - value <- liftIO make - success <- go indent path (fun value) - liftIO $ cleanup value - return success - go indent path (HP name (H.Property config test)) = do + runResource + :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> Int -> TestTree -> MVar Bool -> IO () + runResource mparregion depth (Resource inhname make cleanup fun) outvar = do + let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")" + path = intercalate "/" (reverse (pathitem : revpath)) + idxlist = reverse revidxlist + let ?testCtx = TestCtx (\str -> + outputConcurrent (ansiYellow ++ "## Warning for " ++ path ++ ":" ++ ansiReset ++ + "\n" ++ str ++ "\n")) + submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do + setConsoleRegion makeRegion ('|' : path ++ " [R] making...") + + try make >>= \case + Left (err :: SomeException) -> do + finishConsoleRegion makeRegion $ + ansiRed ++ "Exception building resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err + putMVar outvar False + Right value -> do + suboutvar <- newEmptyMVar + runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion + + _ <- forkIO $ do + success <- readMVar suboutvar + poolSubmit ?pool idxlist Nothing $ do + cleanupRegion <- openConsoleRegion Linear + setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...") + try (cleanup value) >>= \case + Left (err :: SomeException) -> do + finishConsoleRegion cleanupRegion $ + ansiRed ++ "Exception cleaning up resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err + putMVar outvar False + Right () -> do + closeConsoleRegion cleanupRegion + putMVar outvar success + return () + runResource mparregion _ tree outvar = runTreePar mparregion revidxlist revpath tree outvar + +runTreePar mparregion revidxlist revpath (HP name prop) outvar = do + let path = intercalate "/" (reverse (name : revpath)) + idxlist = reverse revidxlist + + submitOrRunIn mparregion idxlist (Just outvar) $ \region -> do + -- outputConcurrent $ "! " ++ path ++ ": Running" + let prefix = path ++ " [T]" + setConsoleRegion region ('|' : prefix) + let progressHandler report = do + str <- H.renderProgress H.EnableColor (Just (fromString "")) report + setConsoleRegion region ('|' : prefix ++ ": " ++ replace '\n' " " str) + (ok, rendered) <- runHP progressHandler revpath name prop + finishConsoleRegion region (path ++ ": " ++ rendered) + return ok + +submitOrRunIn :: (?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> [Int] -> Maybe (MVar a) -> (ConsoleRegion -> IO a) -> IO () +submitOrRunIn Nothing idxlist outvar fun = + poolSubmit ?pool idxlist outvar (openConsoleRegion Linear >>= fun) +submitOrRunIn (Just reg) _idxlist outvar fun = do + result <- fun reg + forM_ outvar $ \mvar -> putMVar mvar result + +data SeqRes = SeqRes + { seqresHaveWarnings :: Any + , seqresAllSuccess :: All + , seqresNumLines :: Sum Int } + deriving (Generic) + deriving (Semigroup, Monoid) via Generically SeqRes + +-- | If all tests are successful, returns the number of output lines produced +runTreeSeq :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool) + => Int -> [String] -> TestTree -> IO SeqRes +runTreeSeq indent revpath (Group opts name trees) = do + putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout + starttm <- getCurrentTime + res <- fmap mconcat . forM trees $ + runTreeSeq (indent + 1) (name : revpath) + endtm <- getCurrentTime + if not (getAny (seqresHaveWarnings res)) && getAll (seqresAllSuccess res) && goCollapse opts && ?istty + then do let thislen = 2*indent + length name - let outputPrefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' - when ?istty $ liftIO $ putStr outputPrefix >> hFlush stdout + let Sum lns = seqresNumLines res + putStrLn $ concat (replicate (lns+1) (ANSI.cursorUpCode 1 ++ ANSI.clearLineCode)) ++ + ANSI.setCursorColumnCode 0 ++ + replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ + ansiGreen ++ "OK" ++ ansiReset ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + return (mempty { seqresNumLines = 1 }) + else return (res <> mempty { seqresNumLines = 1 }) +runTreeSeq indent revpath (Resource _ make cleanup fun) = do + let path = intercalate "/" (reverse revpath) + outputted <- newIORef False + let ?testCtx = TestCtx (\str -> do + atomicModifyIORef' outputted (\_ -> (True, ())) + putStrLn (ansiYellow ++ "## Warning for " ++ path ++ + ":" ++ ansiReset ++ "\n" ++ str)) + res <- try make >>= \case + Left (err :: SomeException) -> do + putStrLn $ ansiRed ++ "Exception building resource at " ++ path ++ ":" ++ ansiReset + print err + return (mempty { seqresAllSuccess = All False }) + Right value -> do + res <- runTreeSeq indent revpath (fun value) + try (cleanup value) >>= \case + Left (err :: SomeException) -> do + putStrLn $ ansiRed ++ "Exception cleaning up resource at " ++ path ++ ":" ++ ansiReset + print err + return (res { seqresAllSuccess = All False }) + Right () -> return res - let (config', seedfun) = applyHedgehogOptions options config - seed <- seedfun + warnings <- readIORef outputted + return (res <> mempty { seqresHaveWarnings = Any warnings }) +runTreeSeq indent path (HP name prop) = do + let thislen = 2*indent + length name + let prefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' + when ?istty $ putStr prefix >> hFlush stdout + (ok, rendered) <- runHP (outputProgress (?maxlen + 2)) path name prop + putStrLn ((if ?istty then ANSI.clearFromCursorToLineEndCode else prefix) ++ rendered) >> hFlush stdout + return (mempty { seqresAllSuccess = All ok, seqresNumLines = 1 }) - starttm <- liftIO getCurrentTime - report <- liftIO $ H.checkReport config' 0 seed test (outputProgress (?maxlen + 2)) - endtm <- liftIO getCurrentTime +runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int) + => (H.Report H.Progress -> IO ()) + -> [String] + -> String -> H.Property -> IO (Bool, String) +runHP progressPrinter revpath name (H.Property config test) = do + let (config', seedfun) = applyHedgehogOptions ?options config + seed <- seedfun - liftIO $ do - when (not ?istty) $ putStr outputPrefix - printResult report (path name) (diffUTCTime endtm starttm) - hFlush stdout + starttm <- getCurrentTime + report <- H.checkReport config' 0 seed test progressPrinter + endtm <- getCurrentTime - let ok = H.reportStatus report == H.OK - modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats - , statsTotal = 1 + statsTotal stats } - return (if ok then Just 1 else Nothing) + rendered <- renderResult report (intercalate "/" (reverse (name : revpath))) (diffUTCTime endtm starttm) + + let ok = H.reportStatus report == H.OK + modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats + , statsTotal = 1 + statsTotal stats } + return (ok, rendered) applyHedgehogOptions :: MonadIO m => Options -> H.PropertyConfig -> (H.PropertyConfig, m H.Seed) applyHedgehogOptions opts config0 = @@ -232,32 +444,77 @@ outputProgress :: (?istty :: Bool) => Int -> H.Report H.Progress -> IO () outputProgress indent report | ?istty = do str <- H.renderProgress H.EnableColor (Just (fromString "")) report - putStr (replace '\n' " " str ++ "\x1B[" ++ show (indent+1) ++ "G") + putStr (replace '\n' " " str ++ ANSI.setCursorColumnCode indent) hFlush stdout | otherwise = return () -printResult :: (?istty :: Bool) => H.Report H.Result -> String -> NominalDiffTime -> IO () -printResult report path timeTaken = do +renderResult :: H.Report H.Result -> String -> NominalDiffTime -> IO String +renderResult report path timeTaken = do str <- H.renderResult H.EnableColor (Just (fromString "")) report case H.reportStatus report of - H.OK -> putStrLn (ansi "\x1B[K" ++ str ++ prettyDuration False (realToFrac timeTaken)) + H.OK -> return (str ++ prettyDuration False (realToFrac timeTaken)) H.Failed failure -> do let H.Report { H.reportTests = count, H.reportDiscards = discards } = report replayInfo = H.skipCompress (H.SkipToShrink count discards (H.failureShrinkPath failure)) ++ " " ++ show (H.reportSeed report) suffix = "\n Flags to reproduce: `-p '" ++ path ++ "' --hedgehog-replay '" ++ replayInfo ++ "'`" - putStrLn (ansi "\x1B[K" ++ str ++ suffix) - _ -> putStrLn (ansi "\x1B[K" ++ str) + return (str ++ suffix) + _ -> return str -printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO () -printStats stats timeTaken - | statsOK stats == statsTotal stats = do - putStrLn $ ansi "\x1B[32m" ++ "All " ++ show (statsTotal stats) ++ - " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" +printStats :: (?istty :: Bool) => Int -> Stats -> NominalDiffTime -> IO () +printStats numTests stats timeTaken + | statsOK stats == numTests = do + putStrLn $ ansiGreen ++ "All " ++ show (statsTotal stats) ++ + " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + | statsOK stats == statsTotal stats = + putStrLn $ ansiRed ++ "Failed (" ++ show (numTests - statsTotal stats) ++ " tests could not run)." ++ + prettyDuration True (realToFrac timeTaken) ++ ansiReset | otherwise = let nfailed = statsTotal stats - statsOK stats - in putStrLn $ ansi "\x1B[31m" ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ - " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" + in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ " tests" ++ + (if statsTotal stats /= numTests then " (" ++ show (numTests - statsTotal stats) ++ " could not run)" else "") ++ + "." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + + +newtype WorkerPool k = WorkerPool (TVar (PQ.MinPQueue k (Terminate PoolJob))) +data PoolJob = forall a. PoolJob (IO a) (Maybe (MVar a)) +data Terminate a = Terminate | Value a + deriving (Eq, Ord) -- Terminate sorts before Value + +withWorkerPool :: Ord k => Int -> (WorkerPool k -> IO a) -> IO a +withWorkerPool numWorkers k = do + chan <- newTVarIO PQ.empty + threads <- forM [0..numWorkers-1] (\i -> forkOn i (worker i chan)) + eres <- try (k (WorkerPool chan)) + case eres of + Left (err :: SomeException) -> do + atomically $ writeTVar chan PQ.empty + forM_ threads killThread + throw err + Right res -> do + readTVarIO chan >>= \case + PQ.Empty -> return () + _ -> throwIO (userError "withWorkerPool: computation exited before all jobs were handled") + return res + where + worker :: Ord k => Int -> TVar (PQ.MinPQueue k (Terminate PoolJob)) -> IO () + worker idx chan = do + job <- atomically $ + readTVar chan >>= \case + PQ.Empty -> retry + (_, j) PQ.:< q -> writeTVar chan q >> return j + case job of + Value (PoolJob action mmvar) -> do + -- outputConcurrent $ "[" ++ show idx ++ "] got job\n" + result <- action + forM_ mmvar $ \mvar -> putMVar mvar result + worker idx chan + Terminate -> return () + +poolSubmit :: Ord k => WorkerPool k -> k -> Maybe (MVar a) -> IO a -> IO () +poolSubmit (WorkerPool chan) key mmvar action = + atomically $ modifyTVar chan $ PQ.insert key (Value (PoolJob action mmvar)) + prettyDuration :: Bool -> Double -> String prettyDuration False x | x < 0.5 = "" @@ -273,3 +530,24 @@ replace x ys = concatMap (\y -> if y == x then ys else [y]) ansi :: (?istty :: Bool) => String -> String ansi | ?istty = id | otherwise = const "" + +ansiRed, ansiYellow, ansiGreen, ansiReset :: (?istty :: Bool) => String +ansiRed = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red]) +ansiYellow = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Vivid ANSI.Yellow]) +ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green]) +ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset]) + +-- getTermIsDark :: IO Bool +-- getTermIsDark = do +-- mclr <- ANSI.getLayerColor ANSI.Background +-- case mclr of +-- Nothing -> return True +-- Just (RGB r g b) -> +-- let cvt n = fromIntegral n / fromIntegral (maxBound `asTypeOf` n) +-- in return $ (cvt r + cvt g + cvt b) / 3 < (0.5 :: Double) + +-- ansiRegionBg :: (?istty :: Bool, ?termisdark :: Bool) => String +-- ansiRegionBg +-- | not ?istty = "" +-- | ?termisdark = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.0 0.05 0.1)] +-- | otherwise = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.95 0.95 1.0)] diff --git a/test/Main.hs b/test/Main.hs index 1b83a2e..05597cc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} @@ -11,35 +12,38 @@ {-# LANGUAGE UndecidableInstances #-} module Main where +import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State import Data.Bifunctor import Data.Int (Int64) import Data.Map.Strict (Map) -import qualified Data.Map.Strict as Map -import qualified Data.Text as T +import Data.Map.Strict qualified as Map +import Data.Text qualified as T import Hedgehog -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range import Test.Framework -import Array -import AST hiding ((.>)) -import AST.Pretty -import AST.UnMonoid -import CHAD.Top -import CHAD.Types -import CHAD.Types.ToTan -import Compile -import qualified Example -import qualified Example.GMM as Example -import Example.Types -import ForwardAD -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep -import Language -import Simplify +import CHAD.Array +import CHAD.AST hiding ((.>)) +import CHAD.AST.Count (pruneExpr) +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Drev.Types.ToTan +import CHAD.Example qualified as Example +import CHAD.Example.GMM qualified as Example +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep +import CHAD.Language +import CHAD.Simplify data TypedValue t = TypedValue (STy t) (Rep t) @@ -63,18 +67,18 @@ simplifyIters iters env | Dict <- envKnown env = SimplIters n -> simplifyN n SimplFix -> simplifyFix --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD simplIters env term input = - let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False env input dterm - in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) +-- gradientByCHAD simplIters env term input = +-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term +-- (out, grad) = interpretOpen False env input dterm +-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' simplIters env term input = - second (second (toTanE env input)) $ - gradientByCHAD simplIters env term input +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) +-- gradientByCHAD' simplIters env term input = +-- second (second (toTanE env input)) $ +-- gradientByCHAD simplIters env term input gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 @@ -92,8 +96,8 @@ extendDN (STMaybe _) Nothing = pure Nothing extendDN (STMaybe t) (Just x) = Just <$> extendDN t x extendDN (STArr _ t) arr = traverse (extendDN t) arr extendDN (STScal sty) x = case sty of - STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) - STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) + STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) STI32 -> pure x STI64 -> pure x STBool -> pure x @@ -217,8 +221,8 @@ genShape = \n tpl -> do shapeDiv :: Shape n -> DimNames n -> Int -> Shape n shapeDiv ShNil _ _ = ShNil - shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) - shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f)) + shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` max lo (n `div` f) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` max lo (n `div` f) shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f) shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f) @@ -240,8 +244,8 @@ genValue topty tpl = case topty of ,liftV Just <$> genValue t (emptyTpl t)] STArr n t -> genShape n tpl >>= lift . genArray t STScal sty -> case sty of - STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10) + STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10) STI32 -> genInt STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] @@ -302,10 +306,12 @@ adTestGen name expr envGenerator = exprS = simplifyFix expr in withCompiled env expr $ \primalfun -> withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name + groupSetCollapse $ testGroup name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS - ,adTestGenChad env envGenerator expr exprS primalSfun] + ,testGroup "chad" + [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun + ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]] adTestGenPrimal :: SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R @@ -336,24 +342,34 @@ adTestGenFwd env envGenerator exprS = diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 -adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) +adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree -adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 - dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermChadSUS = simplifyFix $ unMonoid dtermChadS + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 + dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS + dtermSChadSUSP = simplifyFix $ pruneExpr env dtermSChadSUS in - withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> - testProperty "chad" $ property $ do + withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS + when (not (null output)) $ + outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>" + return fun) $ \fwdartifactC -> + withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> + testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - -- pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) + -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) + let dtermChad20 = simplifyN 20 dtermChad0 + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20)) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20))) + let dtermSChad20 = simplifyN 20 dtermSChad0 + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20)) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -363,38 +379,54 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP - (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) - let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS + (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input) + let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outCompSChadS closeIsh outPrimal + annotate (ppExpr env dtermSChadSUSP) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outSChadSUSP closeIsh outPrimal + diff outCompSChadSUSP closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansCompSChadS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansSChadSUSP closeIshE' tansFwd + diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +withCompiled env expr = withResource' $ do + (fun, output) <- compile env expr + when (not (null output)) $ + outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]) gen_gmm = do @@ -409,7 +441,7 @@ gen_gmm = do vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) - vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vgamma <- Gen.realFloat (Range.constant (-10) 10) vm <- Gen.integral (Range.linear 0 5) let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) k2 = 0.5 * vgamma * vgamma @@ -435,11 +467,30 @@ gen_neural = do lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx + term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -502,25 +553,32 @@ tests_Compile = testGroup "Compile" ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ with @(TPair R R) (pair 0.0 0.0) $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TMaybe (TPair R R)) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ nil ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree @@ -556,9 +614,7 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum @@ -566,6 +622,27 @@ tests_AD = testGroup "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + + ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#x ! pair nil #i) $ + 3 * fst_ #p + 2 * snd_ #p + + ,adTest "idx-coprod" $ term_idx_coprod + + ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ + let_ #n (snd_ (shape #arr)) $ + let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $ + if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x)) + (pair (inr (3 * #x)) (exp #x)))) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#b ! pair nil #i) $ + case_ (fst_ #p) + (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p) + (#b :-> #b * 4) + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x @@ -607,6 +684,38 @@ tests_AD = testGroup "AD" ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm + + ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree + + ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ + idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + + ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #innern (snd_ #sh) $ + let_ #n (#innern * snd_ (fst_ #sh)) $ + let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ + -- ensure the input array to EReshape is shared + idx0 $ sum1i $ + build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) + + ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a + + ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a + + ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + let_ #v 2 $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a + + ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a + + ,adTest "map" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ sum1i $ map_ (#x :-> 2 * #x) #a ] main :: IO () |
