diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | .hlint.yaml | 17 | ||||
| -rw-r--r-- | README.md | 5 | ||||
| -rw-r--r-- | bench/Main.hs | 71 | ||||
| -rw-r--r-- | chad-fast.cabal | 88 | ||||
| -rw-r--r-- | example/Main.hs | 2 | ||||
| -rw-r--r-- | rules/.gitignore | 1 | ||||
| -rw-r--r-- | src/AST/Count.hs | 170 | ||||
| -rw-r--r-- | src/CHAD/APIv1.hs | 178 | ||||
| -rw-r--r-- | src/CHAD/AST.hs (renamed from src/AST.hs) | 206 | ||||
| -rw-r--r-- | src/CHAD/AST/Accum.hs (renamed from src/AST/Accum.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/Bindings.hs (renamed from src/AST/Bindings.hs) | 19 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs | 927 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs (renamed from src/AST/Env.hs) | 20 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs (renamed from src/AST/Pretty.hs) | 78 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse.hs (renamed from src/AST/Sparse.hs) | 44 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs (renamed from src/AST/Sparse/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs (renamed from src/AST/SplitLets.hs) | 53 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs (renamed from src/AST/Types.hs) | 34 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs (renamed from src/AST/UnMonoid.hs) | 24 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken.hs (renamed from src/AST/Weaken.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken/Auto.hs (renamed from src/AST/Weaken/Auto.hs) | 55 | ||||
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs (renamed from src/Analysis/Identity.hs) | 61 | ||||
| -rw-r--r-- | src/CHAD/Array.hs (renamed from src/Array.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/Compile.hs (renamed from src/Compile.hs) | 455 | ||||
| -rw-r--r-- | src/CHAD/Compile/Exec.hs (renamed from src/Compile/Exec.hs) | 16 | ||||
| -rw-r--r-- | src/CHAD/Data.hs (renamed from src/Data.hs) | 4 | ||||
| -rw-r--r-- | src/CHAD/Data/VarMap.hs (renamed from src/Data/VarMap.hs) | 13 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs (renamed from src/CHAD.hs) | 494 | ||||
| -rw-r--r-- | src/CHAD/Drev/Accum.hs (renamed from src/CHAD/Accum.hs) | 17 | ||||
| -rw-r--r-- | src/CHAD/Drev/EnvDescr.hs (renamed from src/CHAD/EnvDescr.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/Drev/Top.hs (renamed from src/CHAD/Top.hs) | 32 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs (renamed from src/CHAD/Types.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs (renamed from src/CHAD/Types/ToTan.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/Example.hs (renamed from src/Example.hs) | 55 | ||||
| -rw-r--r-- | src/CHAD/Example/GMM.hs (renamed from src/Example/GMM.hs) | 7 | ||||
| -rw-r--r-- | src/CHAD/Example/Types.hs (renamed from src/Example/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD.hs (renamed from src/ForwardAD.hs) | 29 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers.hs (renamed from src/ForwardAD/DualNumbers.hs) | 23 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers/Types.hs (renamed from src/ForwardAD/DualNumbers/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/Interpreter.hs (renamed from src/Interpreter.hs) | 83 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Accum.hs (renamed from src/Interpreter/Accum.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/AccumOld.hs (renamed from src/Interpreter/AccumOld.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Rep.hs (renamed from src/Interpreter/Rep.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/Language.hs | 423 | ||||
| -rw-r--r-- | src/CHAD/Language/AST.hs (renamed from src/Language/AST.hs) | 127 | ||||
| -rw-r--r-- | src/CHAD/Lemmas.hs (renamed from src/Lemmas.hs) | 2 | ||||
| -rw-r--r-- | src/CHAD/Simplify.hs (renamed from src/Simplify.hs) | 63 | ||||
| -rw-r--r-- | src/CHAD/Simplify/TH.hs (renamed from src/Simplify/TH.hs) | 4 | ||||
| -rw-r--r-- | src/CHAD/Util/IdGen.hs (renamed from src/Util/IdGen.hs) | 2 | ||||
| -rw-r--r-- | src/Language.hs | 233 | ||||
| -rw-r--r-- | test-framework/Test/Framework.hs | 468 | ||||
| -rw-r--r-- | test/Main.hs | 209 |
53 files changed, 3670 insertions, 1256 deletions
@@ -1,3 +1,5 @@ dist-newstyle/ cabal.project.local .ccls-cache/ + +compile_diagnostics_watcher.sh diff --git a/.hlint.yaml b/.hlint.yaml new file mode 100644 index 0000000..7ec649a --- /dev/null +++ b/.hlint.yaml @@ -0,0 +1,17 @@ +- ignore: {name: "Avoid lambda"} +- ignore: {name: "Avoid lambda using `infix`"} +- ignore: {name: "Collapse lambdas"} +- ignore: {name: "Eta reduce"} +- ignore: {name: "Evaluate"} +- ignore: {name: "Redundant $"} +- ignore: {name: "Redundant lambda"} +- ignore: {name: "Use bimap"} +- ignore: {name: "Use camelCase"} +- ignore: {name: "Use const"} +- ignore: {name: "Use forM_"} +- ignore: {name: "Use newtype instead of data"} +- ignore: {name: "Use record patterns"} +- ignore: {name: "Use tuple-section"} +- ignore: {name: "Use unless"} +- ignore: {name: "Use unwords"} +- ignore: {name: "Use void"} diff --git a/README.md b/README.md new file mode 100644 index 0000000..585e3b0 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# chad-fast + +This is work-in-progress research software. If you are interested in techniques +or ideas in this repository, please do [contact me](https://tomsmeding.com), +I'm always happy to chat. diff --git a/bench/Main.hs b/bench/Main.hs index 358ba31..1e8f6f3 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,11 +1,14 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS -Wno-orphans #-} module Main where @@ -16,26 +19,31 @@ import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) -import AST -import AST.UnMonoid -import Array -import qualified CHAD (defaultConfig) -import CHAD.Top -import CHAD.Types -import Compile -import Data -import Example -import Example.GMM -import Example.Types -import Interpreter.Rep -import Simplify +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Array +import CHAD.Compile +import CHAD.Data +import CHAD.Drev qualified as CHAD (defaultConfig) +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example +import CHAD.Example.GMM +import CHAD.Example.Types +import CHAD.Interpreter.Rep +import CHAD.Simplify gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env)))) gradCHAD config term = - compile knownEnv $ - simplifyFix $ unMonoid $ simplifyFix $ - ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term + compileStderr knownEnv $ + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ + ELet ext (EConst ext STF64 1.0) $ + chad' config knownEnv $ + simplifyFix term type AllNFDataRep :: [Ty] -> Constraint type family AllNFDataRep env where @@ -93,18 +101,19 @@ makeGMMInputs = accumConfig :: CHADConfig accumConfig = chcSetAccum CHAD.defaultConfig -main :: IO () -main = defaultMain - [env (return makeNeuralInputs) $ \inputs -> bgroup "neural" - [env (gradCHAD CHAD.defaultConfig neural) $ \fun -> - bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig neural) $ \fun -> - bench "accum" (nfAppIO fun inputs) - ] - ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm" - [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun -> +bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env)))) + => String -> Ex env R -> SList Value env -> Benchmark +bgroupDefaultAccum name term inputs = + bgroup name + [env (gradCHAD CHAD.defaultConfig term) $ \fun -> bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun -> + ,env (gradCHAD accumConfig term) $ \fun -> bench "accum" (nfAppIO fun inputs) ] + +main :: IO () +main = defaultMain + [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural + ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False) + ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil) ] diff --git a/chad-fast.cabal b/chad-fast.cabal index b7270e4..1eef3ed 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -10,46 +10,49 @@ build-type: Simple library exposed-modules: - Analysis.Identity - Array - AST - AST.Accum - AST.Bindings - AST.Count - AST.Env - AST.Pretty - AST.Sparse - AST.Sparse.Types - AST.SplitLets - AST.Types - AST.UnMonoid - AST.Weaken - AST.Weaken.Auto - CHAD - CHAD.Accum - CHAD.EnvDescr - CHAD.Top - CHAD.Types - CHAD.Types.ToTan - Compile - Compile.Exec - Data - Data.VarMap - Example - Example.GMM - Example.Types - ForwardAD - ForwardAD.DualNumbers - ForwardAD.DualNumbers.Types - Interpreter - -- Interpreter.AccumOld - Interpreter.Rep - Language - Language.AST - Lemmas - Simplify - Simplify.TH - Util.IdGen + -- default ghci module on top + CHAD.Example + + CHAD.Analysis.Identity + CHAD.APIv1 + CHAD.Array + CHAD.AST + CHAD.AST.Accum + CHAD.AST.Bindings + CHAD.AST.Count + CHAD.AST.Env + CHAD.AST.Pretty + CHAD.AST.Sparse + CHAD.AST.Sparse.Types + CHAD.AST.SplitLets + CHAD.AST.Types + CHAD.AST.UnMonoid + CHAD.AST.Weaken + CHAD.AST.Weaken.Auto + CHAD.Compile + CHAD.Compile.Exec + CHAD.Data + CHAD.Data.VarMap + CHAD.Drev + CHAD.Drev.Accum + CHAD.Drev.EnvDescr + CHAD.Drev.Top + CHAD.Drev.Types + CHAD.Drev.Types.ToTan + CHAD.Example.GMM + CHAD.Example.Types + CHAD.ForwardAD + CHAD.ForwardAD.DualNumbers + CHAD.ForwardAD.DualNumbers.Types + CHAD.Interpreter + -- CHAD.Interpreter.AccumOld + CHAD.Interpreter.Rep + CHAD.Language + CHAD.Language.AST + CHAD.Lemmas + CHAD.Simplify + CHAD.Simplify.TH + CHAD.Util.IdGen other-modules: build-depends: base >= 4.19 && < 4.21, @@ -83,7 +86,11 @@ library test-framework exposed-modules: Test.Framework build-depends: base, + ansi-terminal, + concurrent-output, hedgehog, + pqueue, + stm, time, transformers hs-source-dirs: test-framework @@ -98,7 +105,6 @@ test-suite test test-framework, base, containers, - dependent-map, hedgehog, text, transformers, diff --git a/example/Main.hs b/example/Main.hs index 6c36857..28cb7e8 100644 --- a/example/Main.hs +++ b/example/Main.hs @@ -1,6 +1,6 @@ module Main where -import Example +import CHAD.Example main :: IO () diff --git a/rules/.gitignore b/rules/.gitignore index ef5f7e5..2a45cd8 100644 --- a/rules/.gitignore +++ b/rules/.gitignore @@ -2,3 +2,4 @@ *.log *.out *.pdf +rules-?.jpg diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index ca4d7ab..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,170 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Count where - -import Data.Functor.Const -import GHC.Generics (Generic, Generically(..)) - -import AST -import AST.Env -import Data - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) - --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = - getConst . occCountGeneral - (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) - (\(Const o) -> Const o) - (\(Const o1) (Const o2) -> Const (o1 <||> o2)) - (\(Const o) -> Const (scaleMany o)) - - -data OccEnv env where - OccEnd :: OccEnv env -- not necessarily top! - OccPush :: OccEnv env -> Occ -> OccEnv (t : env) - -instance Semigroup (OccEnv env) where - OccEnd <> e = e - e <> OccEnd = e - OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') - -instance Monoid (OccEnv env) where - mempty = OccEnd - -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty - -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') - -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) - -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd - -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv - -occCountGeneral :: forall r env t x. - (forall env'. Monoid (r env')) - => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot - -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env'. r env' -> r env' -> r env') -- ^ alternation - -> (forall env'. r env' -> r env') -- ^ scale-many - -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId - where - go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' - go w = \case - EVar _ _ i -> onehot w i (Occ One One) - ELet _ rhs body -> re rhs <> re1 body - EPair _ a b -> re a <> re b - EFst _ e -> re e - ESnd _ e -> re e - ENil _ -> mempty - EInl _ _ e -> re e - EInr _ _ e -> re e - ECase _ e a b -> re e <> (re1 a `alter` re1 b) - ENothing _ _ -> mempty - EJust _ e -> re e - EMaybe _ a b e -> re a <> re1 b <> re e - ELNil _ _ _ -> mempty - ELInl _ _ e -> re e - ELInr _ _ e -> re e - ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c) - EConstArr{} -> mempty - EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c - ESum1Inner _ e -> re e - EUnit _ e -> re e - EReplicate1Inner _ a b -> re a <> re b - EMaximum1Inner _ e -> re e - EMinimum1Inner _ e -> re e - EConst{} -> mempty - EIdx0 _ e -> re e - EIdx1 _ a b -> re a <> re b - EIdx _ a b -> re a <> re b - EShape _ e -> re e - EOp _ _ e -> re e - ECustom _ _ _ _ _ _ _ a b -> re a <> re b - ERecompute _ e -> re e - EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a _ b e -> re a <> re b <> re e - EZero _ _ e -> re e - EDeepZero _ _ e -> re e - EPlus _ _ a b -> re a <> re b - EOneHot _ _ _ a b -> re a <> re b - EError{} -> mempty - where - re :: Monoid (r env') => Expr x env' t'' -> r env' - re = go w - - re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' - re1 = unpush . go (WSink .> w) - - -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = - deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = - deleteUnused env occenv $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYesR sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYesR _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs new file mode 100644 index 0000000..73d1580 --- /dev/null +++ b/src/CHAD/APIv1.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.APIv1 ( + -- * Expressions and types + Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..), + + -- * Reverse derivatives (Fast CHAD) + vjp, vjp', + D2, D2E, Tup, + CHADConfig(..), + + -- ** Primal type transform + -- | The primal type transform is only important when working with special + -- operations like 'CHAD.Language.custom'. + D1, + + -- * Forward derivatives (dual numbers) + jvp, jvpDN, + Tan, TanS, DN, DNS, DNE, + + -- * Working with expressions + interpret, interpret1, + compile, compile1, + fullSimplify, + SList(..), Value(..), Rep, + KnownEnv(..), KnownTy(..), + SNat(..), +) where + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Compile qualified as Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter qualified as Interpreter +import CHAD.Simplify +import CHAD.Interpreter.Rep + + +-- | Compute a reverse derivative: a vector-Jacobian product. The type has been +-- simplified with the assumption that 'D1' is the identity. +vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp = vjp' (chcSetAccum defaultConfig) + +-- | Same as 'vjp', but supply CHAD configuration. +vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp' config term + | Dict <- styKnown (d2 (typeOf term)) = + fullSimplify $ + unMonoid . simplifyFix $ -- need to merge onehots and accums for unMonoid to do its work + chad' config knownEnv (simplifyFix term) + +jvpDN :: Ex env t -> Ex (DNE env) (DN t) +jvpDN = dfwdDN + +jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t)) +jvp term + | Dict <- styKnown (tanty (knownTy @s)) + = fullSimplify $ + elet (ezipDN knownTy) $ + elet (weakenExpr (WCopy WClosed) (jvpDN term)) $ + eunzipDN (typeOf term) + where + ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s') + ezipDN STNil = ENil ext + ezipDN (STPair a b) = + EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env a)) + (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env b)) + ezipDN (STEither a b) = + ecase (EVar ext (STEither a b) (IS IZ)) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EInl ext (dn b) (ezipDN a)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr")) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl") + (EInr ext (dn a) (ezipDN b))) + ezipDN (STLEither a b) = + elcase (EVar ext (STLEither a b) (IS IZ)) + (ELNil ext (dn a) (dn b)) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN") + (ELInl ext (dn b) (ezipDN a)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr")) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN") + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl") + (ELInr ext (dn a) (ezipDN b))) + ezipDN (STMaybe t) = + emaybe (EVar ext (STMaybe t) (IS IZ)) + (ENothing ext (dn t)) + (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ)) + (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN") + (EJust ext (ezipDN t))) + ezipDN (STArr n t) = + ezipWith (ezipDN t) + (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ) + ezipDN (STScal st) = case st of + STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ) + STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ) + STI32 -> EVar ext (STScal STI32) (IS IZ) + STI64 -> EVar ext (STScal STI64) (IS IZ) + STBool -> EVar ext (STScal STBool) (IS IZ) + ezipDN STAccum{} = error "jvp: Accumulators not supported in source program" + + eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t')) + eunzipDN STNil = EPair ext (ENil ext) (ENil ext) + eunzipDN (STPair a b) = + eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 -> + eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 -> + EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2) + eunzipDN (STEither a b) = + ecase (EVar ext (STEither (dn a) (dn b)) IZ) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (EInl ext b a1) (EInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (EInr ext a b1) (EInr ext (tanty a) b2)) + eunzipDN (STLEither a b) = + elcase (EVar ext (STLEither (dn a) (dn b)) IZ) + (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b))) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2)) + eunzipDN (STMaybe t) = + emaybe (EVar ext (STMaybe (dn t)) IZ) + (EPair ext (ENothing ext t) (ENothing ext (tanty t))) + (eunPair (eunzipDN t) $ \_ e1 e2 -> + EPair ext (EJust ext e1) (EJust ext e2)) + eunzipDN (STArr n t) = + elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $ + EPair ext (emap (EFst ext (evar IZ)) (evar IZ)) + (emap (ESnd ext (evar IZ)) (evar IZ)) + eunzipDN (STScal st) = case st of + STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ + STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ + STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext) + STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext) + STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext) + eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program" + +-- | Interpret an expression in a given environment. +interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t +interpret = Interpreter.interpretOpen False knownEnv + +-- | Special case of 'interpret' for an expression with a single free variable. +interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t +interpret1 x = interpret (Value x `SCons` SNil) + +-- | Compile an expression to C, load the resulting shared object into the +-- program and wrap it in a Haskell function. +compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t)) +compile = Compile.compileStderr knownEnv + +-- | Special case of 'compile' for an expression with a single free variable. +compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t)) +compile1 term = do + f <- Compile.compileStderr knownEnv term + return (\x -> f (Value x `SCons` SNil)) + +-- | Simplify an expression. The 'vjp'/'jvp' functions already do this automatically. +fullSimplify :: KnownEnv env => Ex env t -> Ex env t +fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix diff --git a/src/AST.hs b/src/CHAD/AST.hs index 5aab4fc..ce9eb20 100644 --- a/src/AST.hs +++ b/src/CHAD/AST.hs @@ -1,14 +1,12 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -16,20 +14,20 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where +module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where import Data.Functor.Const import Data.Functor.Identity +import Data.Int (Int64) import Data.Kind (Type) -import Array -import AST.Accum -import AST.Sparse.Types -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST.Accum +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types -- General assumption: head of the list (whatever way it is associated) is the @@ -61,12 +59,35 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (TPair t1 t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -81,6 +102,9 @@ data Expr x env t where -- be backpropagated to; 'a' is the inactive part. The dual field of -- ECustom does not allow a derivative to be generated for 'a', and hence -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! ECustom :: x t -> STy a -> STy b -> STy tape -> Expr x [b, a] t -- ^ regular operation -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass @@ -115,6 +139,14 @@ data Expr x env t where EError :: x a -> STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) +-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full +-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@); +-- 'Ex' sets this to nothing. +-- +-- Construct expressions using the functions in "CHAD.Language". +-- +-- Use 'CHAD.AST.Pretty.pprintExpr' or 'CHAD.AST.Pretty.ppExpr' to inspect +-- expressions. type Ex = Expr (Const ()) ext :: Const () a @@ -206,12 +238,18 @@ typeOf = \case EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) + EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) + + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -253,12 +291,17 @@ extOf = \case ELCase x _ _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x + EMap x _ _ -> x EFold1Inner x _ _ _ _ -> x ESum1Inner x _ -> x EUnit x _ -> x EReplicate1Inner x _ _ -> x EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x + EZip x _ _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -299,12 +342,17 @@ travExt f = \case ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e EUnit x e -> EUnit <$> f x <*> travExt f e EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -358,12 +406,17 @@ subst' f w = \case ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) + EZip x a b -> EZip x (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -447,6 +500,16 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e + _ -> False + eTup :: SList (Ex env) list -> Ex env (Tup list) eTup = mkTup (ENil ext) (EPair ext) @@ -473,34 +536,24 @@ eidxEq (SS n) a b emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) emap f arr - | STArr n t <- typeOf arr + | STArr _ t <- typeOf arr , Dict <- styKnown t - = ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f + = EMap ext f arr ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) ezipWith f arr1 arr2 - | STArr n t1 <- typeOf arr1 + | STArr _ t1 <- typeOf arr1 , STArr _ t2 <- typeOf arr2 , Dict <- styKnown t1 , Dict <- styKnown t2 - = ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f + = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) + IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) + IS (IS i) -> EVar ext t (IS i)) + f) + (EZip ext arr1 arr2) ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = - let STArr _ t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 +ezip = EZip ext eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) @@ -516,6 +569,23 @@ eshapeEmpty (SS n) e = (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) +eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) +eshapeConst ShNil = ENil ext +eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) + +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + -- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) -- ezeroD2 t ezi = EZero ext (d2M t) ezi @@ -527,6 +597,7 @@ eshapeEmpty (SS n) e = eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) eunPair e k = elet e $ k WSink @@ -544,7 +615,13 @@ esnd e = ESnd ext e elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b elet rhs body | Dict <- styKnown (typeOf rhs) - = ELet ext rhs body + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body + +-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b emaybe e a b @@ -552,7 +629,14 @@ emaybe e a b , Dict <- styKnown t = EMaybe ext a b e -elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c +ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +ecase e a b + | STEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ECase ext e a b + +elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c elcase e a b c | STLEither t1 t2 <- typeOf e , Dict <- styKnown t1 @@ -573,3 +657,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t go SMTMaybe{} _ = ENil ext go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e go SMTScal{} _ = ENil ext + +splitSparsePair + :: -- given a sparsity + STy (TPair a b) -> Sparse (TPair a b) t' + -> (forall a' b'. + -- I give you back two sparsities for a and b + Sparse a a' -> Sparse b b' + -- furthermore, I tell you that either your t' is already this (a', b') pair... + -> Either + (t' :~: TPair a' b') + -- or I tell you how to construct a' and b' from t', given an actual t' + (forall r' env. + Idx env t' + -> (forall env'. + (forall c. Ex env' c -> Ex env c) + -> Ex env' a' -> Ex env' b' -> r') + -> r') + -> r) + -> r +splitSparsePair _ SpAbsent k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = + k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = + let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in + k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> + k2 (elet $ + emaybe (EVar ext (STMaybe (applySparse s t)) i) + (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) + (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) + (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = + splitSparsePair t (SpSparse s) $ \s1 s2 eres -> + k s1 s2 $ Right $ \i k2 -> + case eres of + Left refl -> case refl of {} + Right f -> + f IZ $ \wrap e1 e2 -> + k2 (\body -> + elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) + (ENothing ext (applySparse s t)) + (evar IZ)) $ + wrap body) + e1 e2 diff --git a/src/AST/Accum.hs b/src/CHAD/AST/Accum.hs index 988a450..ea74a95 100644 --- a/src/AST/Accum.hs +++ b/src/CHAD/AST/Accum.hs @@ -5,10 +5,10 @@ {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module AST.Accum where +module CHAD.AST.Accum where -import AST.Types -import Data +import CHAD.AST.Types +import CHAD.Data data AcPrj diff --git a/src/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs index 2310f4b..c1a1e77 100644 --- a/src/AST/Bindings.hs +++ b/src/CHAD/AST/Bindings.hs @@ -13,12 +13,12 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module AST.Bindings where +module CHAD.AST.Bindings where -import AST -import AST.Env -import Data -import Lemmas +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data +import CHAD.Lemmas -- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. @@ -28,6 +28,10 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + mapBindings :: (forall env' t'. f env' t' -> g env' t') -> Bindings f env binds -> Bindings g env binds mapBindings _ BTop = BTop @@ -42,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) = let (b', w') = weakenBindings wf w b in (BPush b' (t, wf w' x), WCopy w') +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' weakenOver SNil w = w weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs new file mode 100644 index 0000000..46173d2 --- /dev/null +++ b/src/CHAD/AST/Count.hs @@ -0,0 +1,927 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +module CHAD.AST.Count where + +import Data.Functor.Product +import Data.Some +import Data.Type.Equality +import GHC.Generics (Generic, Generically(..)) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data + + +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where + -- | One of the two branches is taken + (<||>) :: a -> a -> a + -- | This code is executed many times + scaleMany :: a -> a + + +data Count = Zero | One | Many + deriving (Show, Eq, Ord) + +instance Semigroup Count where + Zero <> n = n + n <> Zero = n + _ <> _ = Many +instance Monoid Count where + mempty = Zero +instance Occurrence Count where + (<||>) = max + scaleMany Zero = Zero + scaleMany _ = Many + +data Occ = Occ { _occLexical :: Count + , _occRuntime :: Count } + deriving (Eq, Generic) + deriving (Semigroup, Monoid) via Generically Occ + +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + +instance Occurrence Occ where + Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) + scaleMany (Occ l c) = Occ l (scaleMany c) + + +data Substruc t t' where + -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below + SsFull :: Substruc t t + SsNone :: Substruc t TNil + SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') + SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') + SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') + SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') + SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements + SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) + where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' +pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) + where SsArr' = SsArr +{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} + +instance Semigroup (Some (Substruc t)) where + Some SsFull <> _ = Some SsFull + _ <> Some SsFull = Some SsFull + Some SsNone <> s = s + s <> Some SsNone = s + Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) + Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) + Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) + Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) + Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) + Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where + mempty = Some SsNone + +instance TestEquality (Substruc t) where + testEquality SsFull s = isFull s + testEquality s SsFull = sym <$> isFull s + testEquality SsNone SsNone = Just Refl + testEquality SsNone _ = Nothing + testEquality _ SsNone = Nothing + testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) + | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +-- simplifySubstruc' s1 $ \s1' f1 -> +-- simplifySubstruc' s2 $ \s2' f2 -> +-- case (s1', s2') of +-- (SsFull, SsFull) -> +-- k SsFull (case (f1, f2) of +-- (ExMapId, ExMapId) -> ExMapId +-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where + MapEmpty '[] = '[] + MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where + OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! + OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where + Some OccEnd <> e = e + e <> Some OccEnd = e + Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where + mempty = Some OccEnd + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where + Some OccEnd <||> e = e + e <||> Some OccEnd = e + Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + + scaleMany (Some OccEnd) = Some OccEnd + scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s + | Some env' <- onehotOccEnv i v s + = Some (OccPush env' mempty SsNone) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) + +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone + +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd + +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i + +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) + | Some (Pair s' i') <- occEnvPrjS e i + = Some (Pair s' (IS i')) + +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of + _ | Just Refl <- testEquality topsbig topssmall -> ex + + (SsFull, SsFull) -> ex + (SsNone, SsNone) -> ex + (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" + (_, SsNone) -> + case typeOf ex of + STNil -> ex + _ -> use ex $ ENil ext + + (SsPair s1 s2, SsPair s1' s2') -> + eunPair ex $ \_ e1 e2 -> + EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) + (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex + (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex + + (SsEither s1 s2, SsEither s1' s2') + | STEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in ecase ex + (EInl ext (typeOf e2) e1) + (EInr ext (typeOf e1) e2) + (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex + (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex + + (SsLEither s1 s2, SsLEither s1' s2') + | STLEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in elcase ex + (ELNil ext (typeOf e1) (typeOf e2)) + (ELInl ext (typeOf e2) e1) + (ELInr ext (typeOf e1) e2) + (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex + (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex + + (SsMaybe s1, SsMaybe s1') + | STMaybe t1 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + in emaybe ex + (ENothing ext (typeOf e1)) + (EJust ext e1) + (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex + (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex + + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex + (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex + (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + + (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" + (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex + (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where + EMRest :: Bool -> EnvMask env + EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx ex + | Some env <- occCountAll ex + = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) + where + fullOccEnv :: SList f env -> OccEnv () env env + fullOccEnv SNil = OccEnd + fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse +-- * k: is passed the actual environment usage of this expression, including +-- occurrence counts. The callback reconstructs a new expression in an +-- updated "response" environment. The response must be at least as large as +-- the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t + -> (forall env'. OccEnv Occ env env' + -- response OccEnv must be at least as large as the OccEnv returned above + -> (forall env''. OccEnv () env env'' -> Ex env'' t') + -> r) + -> r +occCountX initialS topexpr k = case topexpr of + EVar _ t i -> + withSome (onehotOccEnv i (Occ One One) s) $ \env -> + k env $ \env' -> + withSome (occEnvPrjS env' i) $ \(Pair s' i') -> + projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') + ELet _ rhs body -> + occCountX s body $ \envB mkbody -> + occEnvPop' envB $ \envB' s1 -> + occCountX s1 rhs $ \envR mkrhs -> + withSome (Some envB' <> Some envR) $ \env -> + k env $ \env' -> + ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) + EPair _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsPair' s1 s2 -> + occCountX s1 a $ \env1 mka -> + occCountX s2 b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPair ext (mka env') (mkb env') + EFst _ e -> + occCountX (SsPair s SsNone) e $ \env1 mke -> + k env1 $ \env' -> + EFst ext (mke env') + ESnd _ e -> + occCountX (SsPair SsNone s) e $ \env1 mke -> + k env1 $ \env' -> + ESnd ext (mke env') + ENil _ -> + case s of + SsFull -> k OccEnd (\_ -> ENil ext) + SsNone -> k OccEnd (\_ -> ENil ext) + EInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + EInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + EInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + EInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + ECase _ e a b -> + occCountX s a $ \env1' mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env1' $ \env1 s1 -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) + ENothing _ t -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EJust _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsMaybe s' -> + occCountX s' e $ \env1 mke -> + k env1 $ \env' -> + EJust ext (mke env') + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EMaybe _ a b e -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsMaybe s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') + ELNil _ t1 t2 -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + ELInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + ELInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELCase _ e a b c -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occCountX s c $ \env3' mkc -> + occEnvPop' env2' $ \env2 s1 -> + occEnvPop' env3' $ \env3 s2 -> + occCountX (SsLEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> + k env $ \env' -> + ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + + EConstArr _ n t x -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) + + EBuild _ n a b -> + case s of + SsNone -> + occCountX SsFull a $ \env1 mka -> + occCountX SsNone b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + use (EBuild ext n (mka env') $ + use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ + ENil ext) $ + ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX s' b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + EBuild ext n (mka env') $ + elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + + EFold1Inner _ commut a b c -> + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of + SsNone -> Some SsNone + SsArr' s' -> Some s' in + withSome (s1 <> s0) $ \sElt -> + occCountX sElt b $ \env2 mkb -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush env' () (SsPair sElt sElt))) + (mkb env') (mkc env') + + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e + + EUnit _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' s' -> + occCountX s' e $ \env mke -> + k env $ \env' -> + EUnit ext (mke env') + + EReplicate1Inner _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX (SsArr s') b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReplicate1Inner ext (mka env') (mkb env') + + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + + EFold1InnerD1 _ cm e1 e2 e3 -> + case s of + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm ef ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkebog env') (mked env') + + EConst _ t x -> + k OccEnd $ \_ -> + case s of + SsNone -> ENil ext + SsFull -> EConst ext t x + + EIdx0 _ e -> + occCountX (SsArr s) e $ \env1 mke -> + k env1 $ \env' -> + EIdx0 ext (mke env') + + EIdx1 _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX (SsArr s') a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx1 ext (mka env') (mkb env') + + EIdx _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + _ -> + occCountX (SsArr s) a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx ext (mka env') (mkb env') + + EShape _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX (SsArr SsNone) e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EShape ext (mke env') + + EOp _ op e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EOp ext op (mke env') + + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') + + ERecompute _ e -> + occCountX s e $ \env1 mke -> + k env1 $ \env' -> + ERecompute ext (mke env') + + EWith _ t a b -> + case s of + SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with + occCountX SsNone b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + withSome (case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone) $ \s1' -> + occCountX s1' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ + ENil ext + SsPair sB sA -> + occCountX sB b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + let s1' = case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone in + withSome (Some sA <> s1') $ \sA' -> + occCountX sA' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ + EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) + SsFull -> occCountX (SsPair SsFull SsFull) topexpr k + + EAccum _ t p a sp b e -> + -- TODO: do better! + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull e $ \env3 mke -> + withSome (Some env1 <> Some env2) $ \env12 -> + withSome (Some env12 <> Some env3) $ \env -> + k env $ \env' -> + case s of {SsFull -> id; SsNone -> id} $ + EAccum ext t p (mka env') sp (mkb env') (mke env') + + EZero _ t e -> + occCountX (subZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EZero ext (applySubstrucM s t) (mke env') + where + subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) + subZeroInfo SsFull = SsFull + subZeroInfo SsNone = SsNone + subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) + subZeroInfo SsEither{} = error "Either is not a monoid" + subZeroInfo SsLEither{} = SsNone + subZeroInfo SsMaybe{} = SsNone + subZeroInfo (SsArr s') = SsArr (subZeroInfo s') + subZeroInfo SsAccum{} = error "Accum is not a monoid" + + EDeepZero _ t e -> + occCountX (subDeepZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EDeepZero ext (applySubstrucM s t) (mke env') + where + subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) + subDeepZeroInfo SsFull = SsFull + subDeepZeroInfo SsNone = SsNone + subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo SsEither{} = error "Either is not a monoid" + subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') + subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') + subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" + + EPlus _ t a b -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPlus ext (applySubstrucM s t) (mka env') (mkb env') + + EOneHot _ t p a b -> + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> -- TODO: do better + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') + + EError _ t msg -> + k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + where + s = simplifySubstruc (typeOf topexpr) initialS + + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr' SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') + + +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = + deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = + deleteUnused env (Some occenv) $ \sub -> + case count of Zero -> k (SENo sub) + _ -> k (SEYesR sub) + +unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv = \sub -> + subst (\x t i -> case sinkViaSubenv i sub of + Just i' -> EVar x t i' + Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") + where + sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) + sinkViaSubenv IZ (SEYesR _) = Just IZ + sinkViaSubenv IZ (SENo _) = Nothing + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/CHAD/AST/Env.hs index 422f0f7..8e6b745 100644 --- a/src/AST/Env.hs +++ b/src/CHAD/AST/Env.hs @@ -7,14 +7,14 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -module AST.Env where +module CHAD.AST.Env where import Data.Type.Equality -import AST.Sparse -import AST.Weaken -import CHAD.Types -import Data +import CHAD.AST.Sparse +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types -- | @env'@ is a subset of @env@: each element of @env@ is either included in @@ -64,6 +64,16 @@ subenvConcat sub1 SETop = sub1 subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) +-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 +-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r +-- subenvSplit SNil sub k = k SETop sub +-- subenvSplit (SCons _ list) (SENo sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SENo sub1) sub2 +-- subenvSplit (SCons _ list) (SEYes s sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SEYes s sub1) sub2 + sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 sinkWithSubenv SETop = WId sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub diff --git a/src/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index fef9686..9ddcb35 100644 --- a/src/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -1,33 +1,31 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where +module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) import Data.Functor.Const -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.String (fromString) import Prettyprinter import Prettyprinter.Render.String -import qualified Data.Text.Lazy as TL -import qualified Prettyprinter.Render.Terminal as PT +import Data.Text.Lazy qualified as TL +import Prettyprinter.Render.Terminal qualified as PT import System.Console.ANSI (hSupportsANSI) import System.IO (stdout) import System.IO.Unsafe (unsafePerformIO) -import AST -import AST.Count -import AST.Sparse.Types -import CHAD.Types -import Data +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types class PrettyX x where @@ -71,6 +69,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -202,16 +201,22 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -234,6 +239,38 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] + + EFold1InnerD1 _ cm a b c -> do + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -385,6 +422,10 @@ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s ppSparse (SMTScal _) SpScal = "." +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -478,4 +519,5 @@ render = else renderString) . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } where + {-# NOINLINE stdoutTTY #-} stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs index 93258b7..85f2882 100644 --- a/src/AST/Sparse.hs +++ b/src/CHAD/AST/Sparse.hs @@ -2,17 +2,15 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where +module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where import Data.Type.Equality -import AST -import AST.Sparse.Types -import Data (SBool(..)) +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data (SBool(..)) sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' @@ -85,9 +83,6 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g) withInj2 Noinj _ _ = Noinj withInj2 _ Noinj _ = Noinj -use :: Ex env a -> Ex env b -> Ex env b -use a b = elet a $ weakenExpr WSink b - -- | This function produces quadratically-sized code in the presence of nested -- dynamic sparsity. TODO can this be improved? sparsePlusS @@ -206,16 +201,27 @@ sparsePlusS SF _ t (SpSparse sp1) sp2 k = emaybe (weakenExpr WSink a) (inj2 (evar IZ)) (plus (evar IZ) (evar (IS IZ)))) -sparsePlusS ST _ t (SpSparse sp1) sp2 k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> EJust ext (inj2 b)) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (EJust ext (inj2 (evar IZ))) - (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k sp3 + (Inj $ \a -> emaybe a (inj2 zero2) (inj1 (evar IZ))) + (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) + | otherwise = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus) diff --git a/src/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 10cac4e..8f41ba4 100644 --- a/src/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -3,13 +3,13 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module AST.Sparse.Types where - -import AST.Types +module CHAD.AST.Sparse.Types where import Data.Kind (Type, Constraint) import Data.Type.Equality +import CHAD.AST.Types + data Sparse t t' where SpSparse :: Sparse t t' -> Sparse t (TMaybe t') diff --git a/src/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs index dcaf82f..34267e4 100644 --- a/src/AST/SplitLets.hs +++ b/src/CHAD/AST/SplitLets.hs @@ -7,13 +7,13 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where +module CHAD.AST.SplitLets (splitLets) where import Data.Type.Equality -import AST -import AST.Bindings -import Lemmas +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Lemmas splitLets :: Ex env t -> Ex env t @@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) @@ -34,7 +34,14 @@ splitLets' = \sub -> \case in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -49,11 +56,14 @@ splitLets' = \sub -> \case ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) @@ -89,15 +99,42 @@ splitLets' = \sub -> \case -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') + bs1 = fst (weakenBindingsE WSink bs1') (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] diff --git a/src/AST/Types.hs b/src/CHAD/AST/Types.hs index 42bfb92..f0feb55 100644 --- a/src/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -8,7 +8,7 @@ {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module AST.Types where +module CHAD.AST.Types where import Data.Int (Int32, Int64) import Data.GADT.Compare @@ -16,7 +16,7 @@ import Data.GADT.Show import Data.Kind (Type) import Data.Type.Equality -import Data +import CHAD.Data type data Ty @@ -31,6 +31,8 @@ type data Ty type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool +-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes +-- convenient, but such scalar types are not special in any way. type STy :: Ty -> Type data STy t where STNil :: STy TNil @@ -171,15 +173,25 @@ type family ScalIsIntegral t where ScalIsIntegral TBool = False -- | Returns true for arrays /and/ accumulators. -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STLEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = True +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True type family Tup env where Tup '[] = TNil diff --git a/src/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index 48dd709..d3cad25 100644 --- a/src/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -3,11 +3,11 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where +module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where -import AST -import AST.Sparse.Types -import Data +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data -- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by @@ -38,12 +38,17 @@ unMonoid = \case ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) EUnit _ e -> EUnit ext (unMonoid e) EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) @@ -99,13 +104,10 @@ plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -- don't destroy the effects! plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext plus (SMTPair t1 t2) a b = - let t = STPair (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) + eunPair a $ \w1 a1 a2 -> + eunPair (weakenExpr w1 b) $ \w2 b1 b2 -> + EPair ext (plus t1 (weakenExpr w2 a1) b1) + (plus t2 (weakenExpr w2 a2) b2) plus (SMTLEither t1 t2) a b = let t = STLEither (fromSMTy t1) (fromSMTy t2) in ELet ext a $ diff --git a/src/AST/Weaken.hs b/src/CHAD/AST/Weaken.hs index d882e28..ac0d152 100644 --- a/src/AST/Weaken.hs +++ b/src/CHAD/AST/Weaken.hs @@ -15,14 +15,15 @@ -- The reason why this is a separate module with "little" in it: {-# LANGUAGE AllowAmbiguousTypes #-} -module AST.Weaken (module AST.Weaken, Append) where +module CHAD.AST.Weaken (module CHAD.AST.Weaken, Append) where import Data.Bifunctor (first) import Data.Functor.Const +import Data.GADT.Compare import Data.Kind (Type) -import Data -import Lemmas +import CHAD.Data +import CHAD.Lemmas type Idx :: [k] -> k -> Type @@ -31,6 +32,11 @@ data Idx env t where IS :: Idx env t -> Idx (a : env) t deriving instance Show (Idx env t) +instance GEq (Idx env) where + geq IZ IZ = Just Refl + geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl + geq _ _ = Nothing + splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) splitIdx SNil i = Right i splitIdx (SCons _ _) IZ = Left IZ @@ -123,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/CHAD/AST/Weaken/Auto.hs index c6efe37..229940b 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/CHAD/AST/Weaken/Auto.hs @@ -1,35 +1,34 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS_GHC -Wno-partial-type-signatures #-} -module AST.Weaken.Auto ( +module CHAD.AST.Weaken.Auto ( autoWeak, (&.), auto, auto1, Layout(..), ) where import Data.Functor.Const +import Data.Kind (Constraint) import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import AST.Weaken -import Data -import Lemmas +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Lemmas type family Lookup name list where @@ -39,18 +38,21 @@ type family Lookup name list where -- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) -- | Pre-weaken with a weakening LPreW :: forall name1 name2 segments. SegmentName name1 -> SegmentName name2 -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments (Lookup name1 segments) - (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where fromLabel = LSeg (symbolSing @name) newtype SegmentName name = SegmentName (SSymbol name) @@ -60,6 +62,18 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh fromLabel = SegmentName symbolSing +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) @@ -74,7 +88,7 @@ auto1 :: SList (Const ()) '[t] auto1 = Const () `SCons` SNil infixr &. -(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) (&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) @@ -118,12 +132,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ seg) +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinAppPreW name1 name2 w LinEnd @@ -151,8 +165,7 @@ pullDown segs name@SSymbol linlayout kNotFound k = k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) -sortLinLayouts :: forall segments env1 env2. - SSegments segments +sortLinLayouts :: SSegments segments -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) @@ -169,8 +182,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" -autoWeak :: forall segments env1 env2. - SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 autoWeak segs ly1 ly2 = preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs index b54946b..212cc7d 100644 --- a/src/Analysis/Identity.hs +++ b/src/CHAD/Analysis/Identity.hs @@ -3,7 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -module Analysis.Identity ( +module CHAD.Analysis.Identity ( identityAnalysis, identityAnalysis', ValId(..), @@ -13,11 +13,11 @@ module Analysis.Identity ( import Data.Foldable (toList) import Data.List (intercalate) -import AST -import AST.Pretty (PrettyX(..)) -import CHAD.Types (d1, d2) -import Data -import Util.IdGen +import CHAD.AST +import CHAD.AST.Pretty (PrettyX(..)) +import CHAD.Data +import CHAD.Drev.Types (d1, d2) +import CHAD.Util.IdGen -- | Every array, scalar and accumulator has an ID. Trivial values such as @@ -202,11 +202,19 @@ idana env expr = case expr of res <- VIArr <$> genId <*> shidsToVec dim shids pure (res, EBuild res dim e1' e2') + EMap _ e1 e2 -> do + let STArr _ t = typeOf e2 + x1 <- genIds t + (_, e1') <- idana (x1 `SCons` env) e1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure sh + pure (res, EMap res e1' e2') + EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 @@ -244,6 +252,41 @@ idana env expr = case expr of res <- VIArr <$> genId <*> pure sh pure (res, EMinimum1Inner res e1') + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + + EZip _ e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + let VIArr _ sh = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EZip res e1' e2') + + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') + EConst _ t val -> do res <- VIScal <$> genId pure (res, EConst res t val) diff --git a/src/Array.hs b/src/CHAD/Array.hs index 707dce2..caf63ef 100644 --- a/src/Array.hs +++ b/src/CHAD/Array.hs @@ -2,19 +2,20 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} -module Array where +module CHAD.Array where import Control.DeepSeq import Control.Monad.Trans.State.Strict import Data.Foldable (traverse_) import Data.Vector (Vector) -import qualified Data.Vector as V +import Data.Vector qualified as V import GHC.Generics (Generic) -import Data +import CHAD.Data data Shape n where @@ -91,6 +92,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) arrayToList :: Array n t -> [t] arrayToList (Array _ v) = V.toList v +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/src/Compile.hs b/src/CHAD/Compile.hs index a5c4fb7..44a335c 100644 --- a/src/Compile.hs +++ b/src/CHAD/Compile.hs @@ -2,13 +2,14 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module CHAD.Compile (compile, compileStderr) where import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) @@ -20,36 +21,37 @@ import Data.Bifunctor (first) import Data.Char (ord) import Data.Foldable (toList) import Data.Functor.Const -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.Functor.Product (Product) import Data.IORef import Data.List (foldl1', intersperse, intercalate) -import qualified Data.Map.Strict as Map +import Data.Map.Strict qualified as Map import Data.Maybe (fromMaybe) -import qualified Data.Set as Set +import Data.Set qualified as Set import Data.Set (Set) import Data.Some -import qualified Data.Vector as V +import Data.Vector qualified as V import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) import System.IO.Unsafe (unsafePerformIO) import Prelude hiding ((^)) -import qualified Prelude +import Prelude qualified -import Array -import AST -import AST.Pretty (ppSTy, ppExpr) -import AST.Sparse.Types (isDense) -import Compile.Exec -import Data -import Interpreter.Rep -import qualified Util.IdGen as IdGen +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty (ppSTy, ppExpr) +import CHAD.AST.Sparse.Types (isDense) +import CHAD.Compile.Exec +import CHAD.Data +import CHAD.Interpreter.Rep +import CHAD.Util.IdGen qualified as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -70,28 +72,30 @@ debugAllocs :: Bool; debugAllocs = toEnum 0 -- | Emit extra C code that checks stuff emitChecks :: Bool; emitChecks = toEnum 0 +-- | Returns compiled function plus compilation output (warnings) compile :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t)) + -> IO (SList Value env -> IO (Rep t), String) compile = \env expr -> do codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source "kernel" + (lib, compileOutput) <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type - return $ \val -> do - allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) - serialiseArguments args ptr $ do - callKernelFun lib ptr - ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) - when (ok /= 1) $ - ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) - deserialise result_type ptr (koResultOffset offsets) + let function val = do + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) + serialiseArguments args ptr $ do + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) + return (function, compileOutput) where serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -99,6 +103,15 @@ compile = \env expr -> do serialiseArguments args ptr k serialiseArguments _ _ k = k +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do + (fun, output) <- compile env expr + when (not (null output)) $ + hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + data StructDecl = StructDecl String -- ^ name @@ -126,7 +139,7 @@ data CExpr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ (<type)<expr> + | CECast String CExpr -- ^ (<type>)<expr> deriving (Show) printStructDecl :: StructDecl -> ShowS @@ -215,23 +228,31 @@ repSTy (STScal st) = case st of STBool -> "uint8_t" repSTy t = genStructName t -genStructName :: STy t -> String -genStructName = \t -> "ty_" ++ gen t where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STLEither a b) = 'L' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen (fromSMTy t) +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the @@ -239,60 +260,56 @@ genStructName = \t -> "ty_" ++ gen t where -- -- For accumulation it is important that for struct representations of monoid -- types, the all-zero-bytes value corresponds to the zero value of that type. -genStruct :: String -> STy t -> [StructDecl] -genStruct name topty = case topty of +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of STNil -> - [StructDecl name "" com] + TreeNode [StructDecl name "" com] [] STPair a b -> - [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] STEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - -- TODO: put shape in the main struct, not the buffer; it's constant, after all -- TODO: no buffer if n = 0 - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] STScal _ -> - [] + TreeNode [] [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] where + name = genStructName topty com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () -genStructs ty = do - let name = genStructName ty - seen <- lift $ gets (name `Set.member`) - - if seen - then pure () - else do - -- already mark this struct as generated now, so we don't generate it - -- twice (unnecessary because no recursive types, but y'know) - lift $ modify (Set.insert name) - - () <- case ty of - STNil -> pure () - STPair a b -> genStructs a >> genStructs b - STEither a b -> genStructs a >> genStructs b - STLEither a b -> genStructs a >> genStructs b - STMaybe t -> genStructs t - STArr _ t -> genStructs t - STScal _ -> pure () - STAccum t -> genStructs (fromSMTy t) - - tell (BList (genStruct name ty)) +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) data CompState = CompState { csStructs :: Set (Some STy) @@ -341,6 +358,12 @@ emitStruct ty = CompM $ do modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } return (genStructName ty) +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + emitTLD :: String -> CompM () emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } @@ -428,10 +451,10 @@ compileToString codeID env expr = else id ,showString $ " const bool success = typed_kernel(" ++ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ - concat (map (\((arg, typ), off) -> - ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ " /* " ++ arg ++ " */") - (zip arg_pairs arg_offsets)) ++ + concat (zipWith (\(arg, typ) off -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + arg_pairs arg_offsets) ++ "\n );\n" ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" @@ -481,19 +504,18 @@ serialise topty topval ptr off k = serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t - allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do when debugRefc $ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh - pokeShape bufptr 0 n sh - pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63) + pokeByteOff @Word64 bufptr 0 (2 ^ 63) - let off1 = fromSNat n * 8 + 8 - loop i + let loop i | i == shapeSize sh = k | otherwise = - serialise t (vec V.! i) bufptr (off1 + i * eltsz) $ + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ loop (i+1) loop 0 (STScal sty, x) -> case sty of @@ -533,13 +555,12 @@ deserialise topty ptr off = else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape bufptr 0 n - refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 when debugRefc $ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let off1 = 8 * fromSNat n + 8 - eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) when (refc < 2 ^ 62) $ free bufptr return arr STScal sty -> case sty of @@ -577,7 +598,7 @@ metricsSTy (STLEither a b) = metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr _ _) = (8, 8) +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) metricsSTy (STScal sty) = case sty of STI32 -> (4, 4) STI64 -> (8, 8) @@ -600,7 +621,7 @@ peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) peekShape ptr off = \case SZ -> return ShNil SS n -> ShCons <$> peekShape ptr off n - <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) + <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + fromSNat n * 8)) compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case @@ -748,15 +769,17 @@ compile' env = \case return (CELit retvar) EConstArr _ n t (Array sh vec) -> do - strname <- emitStruct (STArr n (STScal t)) + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" -- Give it a refcount of _half_ the size_t max, so that it can be -- incremented and decremented at will and will "never" reach anything -- where something happens - emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++ - "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ - ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) EBuild _ n esh efun -> do shname <- compileAssign "sh" env esh @@ -771,7 +794,7 @@ compile' env = \case emit $ SBlock $ pure (SVarDecl False "size_t" linivar (CELit "0")) <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx)))) + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) @@ -780,6 +803,15 @@ compile' env = \case return (CELit arrname) + -- TODO: actually generate decent code here + EMap _ e1 e2 -> do + let STArr n _ = typeOf e2 + compile' env $ + elet e2 $ + EBuild ext n (EShape ext (evar IZ)) $ + elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e1 + EFold1Inner _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr @@ -800,7 +832,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name @@ -809,22 +841,26 @@ compile' env = \case -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> funStmts - <> pure (SAsg accvar funres)) + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) incrementVarAlways "foldx0" Decrement t x0name @@ -846,7 +882,7 @@ compile' env = \case lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) let vecwid = 8 :: Int ivar <- genName' "i" @@ -910,6 +946,149 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + -- TODO: actually generate decent code here + EZip _ e1 e2 -> do + let STArr n _ = typeOf e1 + compile' env $ + elet e1 $ + elet (weakenExpr WSink e2) $ + EBuild ext n (EShape ext (evar (IS IZ))) $ + EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -935,7 +1114,7 @@ compile' env = \case when emitChecks $ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ arrname ++ ".buf); return false;") @@ -991,7 +1170,7 @@ compile' env = \case accname <- genName' "accum" emit $ SVarDecl False actyname accname (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) - emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + emit $ SAsg (accname++".buf->ac") (fromMaybe (CELit name1) mcopy) emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." e2' <- compile' (Const accname `SCons` env) e2 @@ -1044,16 +1223,16 @@ compile' env = \case when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" forM_ [0 .. fromSNat n - 1] $ \j -> do - emit $ SIf (CEBinop (CELit (s ++ ".buf->sh[" ++ show j ++ "]")) + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) "!=" - (CELit (d ++ ".buf->sh[" ++ show j ++ "]"))) + (CELit (d ++ ".sh[" ++ show j ++ "]"))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ d ++ ".buf" ++ - concat [", " ++ d ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ ", " ++ s ++ ".buf" ++ - concat [", " ++ s ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ "); " ++ "return false;") mempty @@ -1110,12 +1289,12 @@ compile' env = \case let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") @@ -1268,21 +1447,21 @@ toLinearIdx SZ _ _ = CELit "0" toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") toLinearIdx (SS n) arrvar idxvar = CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n))))) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) "+" (CELit (idxvar ++ ".b")) -- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr -- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] -- fromLinearIdx (SS n) arrvar idxvar = do -- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) -- _ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" @@ -1297,9 +1476,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr])] - forM_ (zip shape [0::Int ..]) $ \(dim, i) -> - emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") when debugRefc $ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" @@ -1310,16 +1488,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] compileShapeQuery (SS n) var = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) -- | Takes a variable name for the array, not the buffer. compileArrShapeComponents :: SNat n -> String -> [CExpr] compileArrShapeComponents n var = - [CELit (var ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1338,6 +1516,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] @@ -1410,7 +1591,7 @@ compileExtremum nameBase opName operator env e = do lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" @@ -1481,7 +1662,7 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (hasArrays (fromSMTy t)) -> do + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" emit $ SVarDeclUninit toptyname name @@ -1490,7 +1671,7 @@ copyForWriting topty var = case topty of let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" emit $ SVerbatim $ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ ");" emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) @@ -1501,8 +1682,7 @@ copyForWriting topty var = case topty of in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ printCExpr 0 databytes ");"]) @@ -1519,8 +1699,7 @@ copyForWriting topty var = case topty of name <- genName emit $ SVarDecl False toptyname name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain diff --git a/src/Compile/Exec.hs b/src/CHAD/Compile/Exec.hs index 9b9fb15..ffe5661 100644 --- a/src/Compile/Exec.hs +++ b/src/CHAD/Compile/Exec.hs @@ -1,6 +1,5 @@ {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} -module Compile.Exec ( +module CHAD.Compile.Exec ( KernelLib, buildKernel, callKernelFun, @@ -30,7 +29,7 @@ debug = False -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) -buildKernel :: String -> String -> IO KernelLib +buildKernel :: String -> String -> IO (KernelLib, String) buildKernel csource funname = do template <- (++ "/tmp.chad.") <$> getTempDir path <- mkdtemp template @@ -42,7 +41,9 @@ buildKernel csource funname = do ,"-o", outso, "-" ,"-Wall", "-Wextra" ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" - ,"-Wno-unused-parameter", "-Wno-unused-function"] + ,"-Wno-unused-parameter", "-Wno-unused-function" + ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives + ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource -- Print the source before the GCC output. @@ -50,11 +51,6 @@ buildKernel csource funname = do ExitSuccess -> return () ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" - when (not (null gccStdout)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stdout: <<<\n" ++ gccStdout ++ ">>>" - when (not (null gccStderr)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stderr: <<<\n" ++ gccStderr ++ ">>>" - case ec of ExitSuccess -> return () ExitFailure{} -> do @@ -71,7 +67,7 @@ buildKernel csource funname = do _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" dlclose dl) - return (KernelLib ref) + return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr) foreign import ccall "dynamic" wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () diff --git a/src/Data.hs b/src/CHAD/Data.hs index e6978c8..8c7605c 100644 --- a/src/Data.hs +++ b/src/CHAD/Data.hs @@ -8,7 +8,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl), If) where +module CHAD.Data (module CHAD.Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare @@ -18,7 +18,7 @@ import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) -import Lemmas (Append) +import CHAD.Lemmas (Append) data Dict c where diff --git a/src/Data/VarMap.hs b/src/CHAD/Data/VarMap.hs index 2712b08..a0d7617 100644 --- a/src/Data/VarMap.hs +++ b/src/CHAD/Data/VarMap.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -module Data.VarMap ( +module CHAD.Data.VarMap ( VarMap, empty, insert, @@ -20,16 +21,16 @@ module Data.VarMap ( import Prelude hiding (lookup) -import qualified Data.Map.Strict as Map +import Data.Map.Strict qualified as Map import Data.Map.Strict (Map) import Data.Maybe (mapMaybe) import Data.Some -import qualified Data.Vector.Storable as VS +import Data.Vector.Storable qualified as VS import Unsafe.Coerce -import AST.Env -import AST.Types -import AST.Weaken +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken type role VarMap _ nominal -- ensure that 'env' is not phantom diff --git a/src/CHAD.hs b/src/CHAD/Drev.hs index 143376a..bfa964b 100644 --- a/src/CHAD.hs +++ b/src/CHAD/Drev.hs @@ -3,13 +3,12 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} @@ -22,7 +21,7 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module CHAD ( +module CHAD.Drev ( drev, freezeRet, CHADConfig(..), @@ -35,22 +34,21 @@ module CHAD ( import Data.Functor.Const import Data.Some import Data.Type.Equality (type (==), testEquality) -import GHC.Stack (HasCallStack) -import Analysis.Identity (ValId(..), validSplitEither) -import AST -import AST.Bindings -import AST.Count -import AST.Env -import AST.Sparse -import AST.Weaken.Auto -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap -import Data.VarMap (VarMap) -import Lemmas +import CHAD.Analysis.Identity (ValId(..), validSplitEither) +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.Weaken.Auto +import CHAD.Data +import CHAD.Data.VarMap qualified as VarMap +import CHAD.Data.VarMap (VarMap) +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types +import CHAD.Lemmas ------------------------------ TAPES AND BINDINGS ------------------------------ @@ -147,7 +145,7 @@ growRecon t ts (Reconstructor unfbs bs) -- Add a 'fst' at the bottom of the builder stack. -- First we have to weaken most of 'bs' to skip one more binding in the -- unfolder stack above it. - (BPush (fst (weakenBindings weakenExpr + (BPush (fst (weakenBindingsE (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) (WSink :: env :> (Tape (t : ts) : env))) bs)) (t @@ -197,14 +195,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) -- incidentally also add a bunch of additional bindings, namely 'Reverse -- (TapeUnfoldings binds)', so the calling code just has to skip those in -- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) - -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = - let Reconstructor unf build = buildReconstructor binds - in (fst $ weakenBindings weakenExpr (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) ---------------------------------- DERIVATIVES --------------------------------- @@ -485,11 +483,6 @@ expandSubenvZeros w (SCons t ts) (SENo sub) e = (expandSubenvZeros (WPop w) ts sub e) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) -assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" - --------------------------------- ACCUMULATORS --------------------------------- @@ -664,7 +657,7 @@ weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindings weakenExpr w binds + let (binds', _) = weakenBindingsE w binds in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. @@ -705,7 +698,7 @@ freezeRet :: Descr env sto -> Ret env sto (D2 t) t -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) library = #d (auto1 @(D2 t)) @@ -782,17 +775,17 @@ drev des accumMap sd = \case (ENil ext) ELet _ (rhs :: Expr _ _ a) body - | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + Ret (bconcat (rhs0 `bpush` rhs1) body0') (subenvConcat subtapeRHS subtapeBody) (weakenExpr wbody0' body1) subBoth @@ -881,8 +874,8 @@ drev des accumMap sd = \case ECase _ e (a :: Expr _ _ t) b | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge , let (bindids1, bindids2) = validSplitEither (extOf e) , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 <- drevScoped des accumMap t1 storage1 bindids1 sd a @@ -900,8 +893,8 @@ drev des accumMap sd = \case , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) @@ -911,18 +904,16 @@ drev des accumMap sd = \case -> subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> - Ret (e0 `BPush` - (tPrimal, - ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))) + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut (elet (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ - in letBinds rebinds $ + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ ELet ext (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ elet @@ -937,8 +928,8 @@ drev des accumMap sd = \case a2) $ EPair ext (sAB_A $ EFst ext (evar IZ)) (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) - (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ - in letBinds rebinds $ + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ ELet ext (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ elet @@ -976,7 +967,7 @@ drev des accumMap sd = \case (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) + Ret (e0 `bpush` e1) (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub @@ -984,15 +975,15 @@ drev des accumMap sd = \case (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - ECustom _ _ tb storety srce pr du a b + ECustom _ _ tb _ srce pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> case isDense (d2M (typeOf srce)) sd of Just Refl -> - Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) (SEYesR (SENo (SENo (SENo bsubtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub @@ -1000,9 +991,9 @@ drev des accumMap sd = \case weakenExpr (WCopy (WSink .> WSink)) b2) Nothing -> - Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))) + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) (SEYesR (SENo (SENo bsubtape))) (EFst ext (EVar ext (typeOf pr) IZ)) bsub @@ -1023,7 +1014,7 @@ drev des accumMap sd = \case (subenvAll (desD1E usedDes)) (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) (subenvCompose subMergeUsed' sub) - (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ weakenExpr (autoWeak (#d (auto1 @sd) &. #shbinds (bindingsBinds e0) @@ -1050,80 +1041,171 @@ drev des accumMap sd = \case (subenvNone (d2e (select SMerge des))) (ENil ext) - EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) | SpArr @_ @sdElt sdElt <- sd - , let eltty = typeOf orige + , let eltty = typeOf ef , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> - case assertSubenvEmpty sub of { Refl -> - case lemAppendNil @e_binds of { Refl -> - let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - Ret (mergePrimalBindings - `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #propr :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) - w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) - in EPair ext (weakenExpr w e1) (collectexpr w'))) - `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) - (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) - (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in + drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds + `bpush` weakenExpr (wSinks (d1e provars)) + (EBuild ext ndim + (drevPrimal des she) + (letBinds e0 $ + EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @sdElt) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim sdElt)) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #propr (d1e envPro) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - }}} + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap _ ef (earr :: Expr _ _ (TArr n a)) + | SpArr sdElt <- sd + , let STArr ndim t1 = typeOf earr + t2 = typeOf ef -> + drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> + case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds + ttape = typeOf ef1tape + library = #d1env (desD1E des) + &. #a0 (bindingsBinds ea0) + &. #atapebinds (subList (bindingsBinds ea0) easubtape) + &. #propr (d1e provars) + &. #x (d1 t1 `SCons` SNil) + &. #parr (STArr ndim (d1 t1) `SCons` SNil) + &. #tapearr (STArr ndim ttape `SCons` SNil) + &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) + &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) + &. #tape (ttape `SCons` SNil) + &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) + &. #d2acEnv (d2ace (select SAccum des)) + &. #pro (d2ace provars) + in + subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> + Ret (bconcat ea0 proPrimalBinds' + `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 + `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) + (letBinds ef0 $ + EPair ext ef1 ef1tape)) + (EVar ext (STArr ndim (d1 t1)) IZ) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) + (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) + subfa + (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout) $ + emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ + elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) + (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) + ef2) + (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) + (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ + plus_f_a + (ESnd ext (evar IZ)) + (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) + (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) + ea2))) + } + + EFold1Inner _ commut origef exâ‚€ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsxâ‚€a subtapexâ‚€a (RetPair exâ‚€1 subxâ‚€ exâ‚€2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) exâ‚€) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsxâ‚€a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #pxâ‚€ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #xâ‚€abinds (bindingsBinds bindsxâ‚€a) + &. #fbinds (bindingsBinds ef0) + &. #xâ‚€atapebinds (subList (bindingsBinds bindsxâ‚€a) subtapexâ‚€a) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace provars) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#xâ‚€abinds :++: #d1env) ((#propr :++: #xâ‚€abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subxâ‚€ suba $ \subxâ‚€a _ _ plus_xâ‚€_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subxâ‚€a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subxâ‚€af _ _ plus_xâ‚€a_f -> + Ret (bconcat bindsxâ‚€a proPrimalBinds' + `bpush` weakenExpr wOverPrimalBindings exâ‚€1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (let layout = #xy :++: #parr :++: #pzi :++: #pxâ‚€ :++: (#propr :++: #xâ‚€abinds) :++: #d1env in + weakenExpr (autoWeak library (#xy :++: #d1env) layout) + (letBinds ef0 $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + ef1 + (EPair ext + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) + ef1tape))) + (EVar ext (d1 eltty) (IS (IS IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapexâ‚€a (subenvAll (d1e provars))))))) + (EFst ext (EVar ext primalTy IZ)) + subxâ‚€af + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #xâ‚€atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ + plus_xâ‚€a_f + (plus_xâ‚€_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#xâ‚€atapebinds :++: #d2acEnv) layout1)) + exâ‚€2) + (weakenExpr (WCopy (autoWeak library (#xâ‚€atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) EUnit _ e | SpArr sdElt <- sd @@ -1147,9 +1229,8 @@ drev des accumMap sd = \case (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) sub (ELet ext (EFold1Inner ext Commut - (sparsePlus (d2M eltty) sdElt' - (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) - (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (inj2 (ENil ext)) (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ weakenExpr (WCopy WSink) e2) @@ -1172,8 +1253,8 @@ drev des accumMap sd = \case | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) @@ -1191,9 +1272,9 @@ drev des accumMap sd = \case , let tIxN = tTup (sreplicate n tIx) -> sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))) + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) (SEYesR (SEYesR (SENo subtape))) (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) @@ -1225,8 +1306,8 @@ drev des accumMap sd = \case | SpArr sd' <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> - Ret (e0 `BPush` (STArr (SS n) t, e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub @@ -1238,8 +1319,47 @@ drev des accumMap sd = \case EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" @@ -1255,10 +1375,14 @@ drev des accumMap sd = \case EPlus{} -> err_monoid EOneHot{} -> err_monoid + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) @@ -1274,8 +1398,8 @@ deriv_extremum extremum des accumMap sd e , let tIxN = tTup (sreplicate (SS n) tIx) = sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) (SEYesR (SEYesR subtape)) (EVar ext at' IZ) sub @@ -1367,6 +1491,88 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of , Refl <- lemAppendNil @tapebinds -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds IZ) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e diff --git a/src/CHAD/Accum.hs b/src/CHAD/Drev/Accum.hs index 7212232..6f25f11 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Drev/Accum.hs @@ -1,13 +1,13 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeOperators #-} -- | TODO this module is a grab-bag of random utility functions that are shared --- between CHAD and CHAD.Top. -module CHAD.Accum where +-- between CHAD.Drev and CHAD.Drev.Top. +module CHAD.Drev.Accum where -import AST -import CHAD.Types -import Data -import AST.Env +import CHAD.AST +import CHAD.Data +import CHAD.Drev.Types +import CHAD.AST.Env d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) @@ -44,6 +44,11 @@ d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators _ SNil e = e makeAccumulators w (t `SCons` envpro) e = diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs index 49ae0e6..5a90303 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/Drev/EnvDescr.hs @@ -7,18 +7,18 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.EnvDescr where +module CHAD.Drev.EnvDescr where import Data.Kind (Type) import Data.Some import GHC.TypeLits (Symbol) -import Analysis.Identity (ValId(..)) -import AST.Env -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Analysis.Identity (ValId(..)) +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types type Storage :: Symbol -> Type diff --git a/src/CHAD/Top.hs b/src/CHAD/Drev/Top.hs index 484779e..65b4dee 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Drev/Top.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} @@ -8,20 +8,20 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Top where +module CHAD.Drev.Top where -import Analysis.Identity -import AST -import AST.Env -import AST.Sparse -import AST.SplitLets -import AST.Weaken.Auto -import CHAD -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import CHAD.Data.VarMap qualified as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types type family MergeEnv env where @@ -43,8 +43,8 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) - else k (des `DPush` (t, Nothing, SMerge)) + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) reassembleD2E :: Descr env sto -> D1E env :> env' diff --git a/src/CHAD/Types.hs b/src/CHAD/Drev/Types.hs index 44ac20e..367a974 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -2,11 +2,11 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Types where +module CHAD.Drev.Types where -import AST.Accum -import AST.Types -import Data +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data type family D1 t where diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs index 888fed4..019119c 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -1,14 +1,14 @@ {-# LANGUAGE GADTs #-} -module CHAD.Types.ToTan where +module CHAD.Drev.Types.ToTan where import Data.Bifunctor (bimap) -import Array -import AST.Types -import CHAD.Types -import Data -import ForwardAD -import Interpreter.Rep +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) diff --git a/src/Example.hs b/src/CHAD/Example.hs index b320ead..34ff889 100644 --- a/src/Example.hs +++ b/src/CHAD/Example.hs @@ -7,26 +7,44 @@ {-# LANGUAGE TypeApplications #-} {-# OPTIONS -Wno-unused-imports #-} -module Example where - -import Array -import AST -import AST.Pretty -import AST.UnMonoid -import CHAD -import CHAD.Top -import ForwardAD -import Interpreter -import Language -import Simplify +module CHAD.Example where import Debug.Trace -import Example.Types + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Data +import CHAD.Drev +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.Interpreter +import CHAD.Language as L +import CHAD.Simplify -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) +pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +pipeline config term + | Dict <- styKnown (d2 (typeOf term)) = + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ chad' config knownEnv $ + simplifyFix $ term + +-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures +pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () +pipeline' config term + | Dict <- styKnown (d2 (typeOf term)) = + pprintExpr (pipeline config term) + + bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) @@ -166,3 +184,14 @@ neuralGo = (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) + +-- The build body uses free variables in a non-linear way, so their primal +-- values are required in the dual of the build. Thus, compositionally, they +-- are stored in the tape from each individual lambda invocation. This results +-- in n copies of y and z, where only one copy would have sufficed. +exUniformFree :: Ex '[R, I64] R +exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ + let_ #y (#x * 2) $ + let_ #z (#x * 3) $ + idx0 $ sum1i $ + build1 #n $ #i :-> #y * #z + toFloat_ #i diff --git a/src/Example/GMM.hs b/src/CHAD/Example/GMM.hs index 206e534..18641e8 100644 --- a/src/Example/GMM.hs +++ b/src/CHAD/Example/GMM.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeApplications #-} -module Example.GMM where +module CHAD.Example.GMM where -import Example.Types -import Language +import CHAD.Data (SList(..)) +import CHAD.Example.Types +import CHAD.Language diff --git a/src/Example/Types.hs b/src/CHAD/Example/Types.hs index d63159b..1e2f72d 100644 --- a/src/Example/Types.hs +++ b/src/CHAD/Example/Types.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} -module Example.Types where +module CHAD.Example.Types where -import AST -import Data +import CHAD.AST +import CHAD.Data type R = TScal TF64 diff --git a/src/ForwardAD.hs b/src/CHAD/ForwardAD.hs index b353def..0ae88ce 100644 --- a/src/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -4,24 +4,25 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD where +module CHAD.ForwardAD where import Data.Bifunctor (bimap) +import Data.Foldable (fold) import System.IO.Unsafe -- import Debug.Trace --- import AST.Pretty +-- import CHAD.AST.Pretty -import Array -import AST -import Compile -import Data -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.Compile +import CHAD.Data +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep --- | Tangent along a type (coincides with cotangent for these types) +-- | Tangent along a type (coincides with the cotangent, t'CHAD.Drev.Types.D2', for these types) type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) @@ -89,7 +90,7 @@ tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x -tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x +tanScalars (STArr _ t) x = fold $ arrayMap (tanScalars t) x tanScalars (STScal STI32) _ = [] tanScalars (STScal STI64) _ = [] tanScalars (STScal STF32) x = [realToFrac x] @@ -254,8 +255,10 @@ makeFwdADArtifactInterp env expr = in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) {-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) -makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) +makeFwdADArtifactCompile env expr = do + (fun, output) <- compile (dne env) (dfwdDN expr) + return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) diff --git a/src/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs index 3ab08af..540ec2b 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -1,11 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -- I want to bring various type variables in scope using type annotations in @@ -14,14 +13,14 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD.DualNumbers ( +module CHAD.ForwardAD.DualNumbers ( dfwdDN, DN, DNS, DNE, dn, dne, ) where -import AST -import Data -import ForwardAD.DualNumbers.Types +import CHAD.AST +import CHAD.Data +import CHAD.ForwardAD.DualNumbers.Types dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) @@ -153,10 +152,11 @@ dfwdDN = \case (EConstArr ext n t x) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) + pairty = STPair (STScal t) (STScal t) in scalTyCase t (ELet ext (dfwdDN e) $ ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) @@ -168,6 +168,9 @@ dfwdDN = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) @@ -194,9 +197,13 @@ dfwdDN = \case EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" deriv_extremum :: ScalIsNumeric t ~ True => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) @@ -220,4 +227,4 @@ dfwdDN = \case (EFst ext (EVar ext tIxN (IS IZ))))))) (ESnd ext (EVar ext t2 (IS IZ))) (zeroScalarConst t)))) - (EMaximum1Inner ext (dfwdDN e)) + (extremum (dfwdDN e)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs index dcacf5f..5d5dd9e 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -1,10 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD.DualNumbers.Types where +module CHAD.ForwardAD.DualNumbers.Types where -import AST.Types -import Data +import CHAD.AST.Types +import CHAD.Data -- | Dual-numbers transformation diff --git a/src/Interpreter.hs b/src/CHAD/Interpreter.hs index ffc2929..6410b5b 100644 --- a/src/Interpreter.hs +++ b/src/CHAD/Interpreter.hs @@ -5,40 +5,40 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Interpreter ( +module CHAD.Interpreter ( interpret, interpretOpen, Value(..), ) where import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.Int (Int64) import Data.IORef +import Data.Tuple (swap) import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace -import Array -import AST -import AST.Pretty -import AST.Sparse.Types -import Data -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Interpreter.Rep newtype AcM s a = AcM { unAcM :: IO a } @@ -113,13 +113,16 @@ interpret'Rec env = \case EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) + EMap _ a b -> do + let STArr _ t = typeOf b + arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b EFold1Inner _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] ESum1Inner _ e -> do arr <- interpret' env e let STArr _ (STScal t) = typeOf e @@ -143,6 +146,50 @@ interpret'Rec env = \case sh `ShCons` n = arrayShape arr numericIsNum t $ return $ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr + EZip _ a b -> do + arr1 <- interpret' env a + arr2 <- interpret' env b + let sh = arrayShape arr1 + when (sh /= arrayShape arr2) $ + error "Interpreter: mismatched shapes in EZip" + return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) @@ -411,3 +458,11 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y diff --git a/src/Interpreter/Accum.hs b/src/CHAD/Interpreter/Accum.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/Accum.hs +++ b/src/CHAD/Interpreter/Accum.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/AccumOld.hs b/src/CHAD/Interpreter/AccumOld.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/AccumOld.hs +++ b/src/CHAD/Interpreter/AccumOld.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs index 1682303..fadc6be 100644 --- a/src/Interpreter/Rep.hs +++ b/src/CHAD/Interpreter/Rep.hs @@ -3,7 +3,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module Interpreter.Rep where +module CHAD.Interpreter.Rep where import Control.DeepSeq import Data.Coerce (coerce) @@ -12,10 +12,10 @@ import Data.Foldable (toList) import Data.IORef import GHC.Exts (withDict) -import Array -import AST -import AST.Pretty -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.Data type family Rep t where diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs new file mode 100644 index 0000000..6621eef --- /dev/null +++ b/src/CHAD/Language.hs @@ -0,0 +1,423 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Language ( + -- * Named expressions + fromNamed, + NExpr, NFun, + + -- * Functions + lambda, + body, + inline, + (.$), + + -- * Basic language constructs + let_, + pair, fst_, snd_, nil, + inl, inr, case_, + nothing, just, maybe_, + + -- * Array operations + constArr_, + build1, build2, build, + map_, + fold1i, fold1i', + sum1i, + unit, + replicate1i, + maximum1i, minimum1i, + reshape, + fold1iD1, fold1iD1', + fold1iD2, + + -- * Scalar operations + -- | Note that 'NExpr' is also an instance of some numeric classes like 'Num' and 'Floating'. + const_, + idx0, + (!), + shape, + length_, + error_, + (.==), (.<), (CHAD.Language..>), (.<=), (.>=), + not_, and_, or_, + mod_, round_, toFloat_, idiv, + + -- * Control flow + if_, + + -- * Special operations + custom, + recompute, + with, accum, accumS, + oper, oper2, + + -- * Helper types + (:->)(..), + + -- * Reexports + TIx, + Lookup, + Ex, + Ty(..), + SNat(..), Nat(..), N0, N1, N2, N3, +) where + +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.Language.AST + + +-- | Helper type, used for e.g. 'case_' and 'build'. +data a :-> b = a :-> b + deriving (Show) +infixr 0 :-> + + +-- | See 'fromNamed' for a usage example. +body :: NExpr env t -> NFun env env t +body = NBody + +-- | See 'fromNamed' for a usage example. +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda = NLam + +-- | Inline a function here, with the given list of expressions as arguments. +-- While this is a normal 'SList', the @params@ list is reversed from the +-- natural argument order of the function; the '(.$)' helper operator serves to +-- "fix" the order. +-- +-- @ +-- let fun = 'lambda' \@(TScal TF64) #x $ 'lambda' \@(TScal TBool) #b $ 'body' $ if_ #b #x (#x + 1) +-- in 'inline' fun ('SNil' .$ 16 .$ 'const_' True) +-- @ +-- +-- Note that no 'const_' is needed for the @16@, because 'NExpr' implements +-- 'Num'. +inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t +inline = inlineNFun + +-- | Helper for constructing the argument list for 'inline'; +-- @(.$) = flip 'SCons'@. See 'inline'. +(.$) :: SList f list -> f a -> SList f (a : list) +(.$) = flip SCons + + +-- | The first 'Var' argument is the left-hand side of this let-binding. For example: +-- +-- @ +-- 'fromNamed' $ 'lambda' \@(TScal TI64) #a $ 'body' $ +-- 'let_' #x (#a + 1) $ +-- #x * #a +-- @ +-- +-- This produces an expression of type @'Ex' '[TScal TI64] (TScal TI64)@ that +-- corresponds to the Haskell code @\\a -> let x = a + 1 in x * a@. +let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t +let_ = NELet + +pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) +pair = NEPair + +fst_ :: NExpr env (TPair a b) -> NExpr env a +fst_ = NEFst + +snd_ :: NExpr env (TPair a b) -> NExpr env b +snd_ = NESnd + +nil :: NExpr env TNil +nil = NENil + +inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) +inl = NEInl knownTy + +inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) +inr = NEInr knownTy + +-- | A @case@ expression on @Either@s. For example, the following expression +-- will evaluate to 10 + 1 = 11: +-- +-- @ +-- 'case_' ('inl' 10) +-- (#x :-> #x + 1) +-- (#y :-> #y * 2) +-- @ +case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c +case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 + +nothing :: KnownTy a => NExpr env (TMaybe a) +nothing = NENothing knownTy + +just :: NExpr env a -> NExpr env (TMaybe a) +just = NEJust + +-- | Analogue of the 'Prelude.maybe' function in the Haskell Prelude: +-- +-- @ +-- 'maybe_' 2 (#x :-> #x * 3) (...) +-- @ +-- +-- will return 2 if @(...)@ is @Nothing@ and @x + 3@ if it is @Just x@. +maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b +maybe_ a (v :-> b) c = NEMaybe a v b c + +-- | To construct 'Array' values, see "CHAD.Array". +constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConstArr knownNat ty x + +-- | Special case of 'build' for 1-dimensional arrays. This produces the array +-- [0.0, 1.0, 2.0]: +-- +-- @ +-- 'build1' 3 (#i :-> 'toFloat_' #i) +-- @ +build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) +build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) + +-- | Special case of 'build' for 2-dimensional arrays. +build2 :: NExpr env TIx -> NExpr env TIx + -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) + -> NExpr env (TArr (S (S Z)) t) +build2 a1 a2 (v1 :-> v2 :-> b) = + NEBuild (SS (SS SZ)) + (pair (pair nil a1) a2) + #idx + (let_ v1 (snd_ (fst_ #idx)) $ + let_ v2 (NEDrop SZ (snd_ #idx)) $ + NEDrop (SS (SS SZ)) b) + +-- | General n-dimensional elementwise array constructor. A 3-dimensional index +-- looks like @((((), i1), i2), i3)@; other dimensionalities are analogous. The +-- innermost dimension (i.e. whose index variable varies the fastest in the +-- standard memory layout) is the right-most index, i.e. @i3@ in 3D example. To +-- create a 10-by-10 table of (row, column) pairs: +-- +-- @ +-- 'build' ('SS' ('SS' 'SZ')) ('pair' ('pair' 'nil' 10) 10) (#i :-> #j :-> 'pair' #i #j) +-- @ +build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) +build n a (v :-> b) = NEBuild n a v b + +map_ :: forall n a b env name. (KnownNat n, KnownTy a) + => (Var name a :-> NExpr ('(name, a) : env) b) + -> NExpr env (TArr n a) -> NExpr env (TArr n b) +map_ (v :-> a) b = NEMap v a b + +-- | Fold over the innermost dimension of an array, thus reducing its dimensionality by one. +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | The underlying AST constructor for a fold takes a function with /one/ +-- argument: a pair of inputs. 'fold1i'' directly returns this AST constructor +-- in case it is helpful for testing. The 'fold1i' function is a convenience +-- wrapper around 'fold1i''. +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 + +sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +sum1i e = NESum1Inner e + +unit :: NExpr env t -> NExpr env (TArr Z t) +unit = NEUnit + +replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) +replicate1i n a = NEReplicate1Inner n a + +maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +maximum1i e = NEMaximum1Inner e + +minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +minimum1i e = NEMinimum1Inner e + +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + +-- | 'fold1iD1'' with a curried combination function. +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | Primal of a fold. Not supported in the input program for reverse differentiation. +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 + +-- | Reverse pass of a fold. Not supported in the input program for reverse differentiation. +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + +const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) +const_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty x + +idx0 :: NExpr env (TArr Z t) -> NExpr env t +idx0 = NEIdx0 + +-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +-- (.!) = NEIdx1 +-- infixl 9 .! + +-- | Index an array. Note that the index is a tuple, just like the argument to +-- the function in 'build'. To index a 2-dimensional array @a@ at row @i@ and +-- column @j@, write @a '!' 'pair' ('pair' 'nil' i) j@. +(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx +infixl 9 ! + +shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) +shape = NEShape + +-- | Convenience special case of 'shape' for single-dimensional arrays. +length_ :: NExpr env (TArr N1 t) -> NExpr env TIx +length_ e = snd_ (shape e) + +oper :: SOp a t -> NExpr env a -> NExpr env t +oper = NEOp + +oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t +oper2 op a b = NEOp op (pair a b) + +error_ :: KnownTy t => String -> NExpr env t +error_ s = NEError knownTy s + +-- | Specify a custom reverse derivative for a subexpression. Morally, the type +-- of this combinator should be read as follows: +-- +-- @ +-- custom :: (a -> b -> t) -- normal semantics +-- -> (D1 a -> D1 b -> (D1 t, tape)) -- forward pass +-- -> (tape -> D2 t -> D2 b) -- reverse pass +-- -> a -> b -- arguments +-- -> t -- result +-- @ +-- +-- In normal evaluation, or when forward-differentiating, the first argument is +-- taken and the second and third are ignored. When reverse-differentiating +-- using CHAD, however, the /first/ argument is ignored and the second and +-- third arguments are respectively put in the forward and the reverse passes +-- of the derivative program. The @tape@ value may be used to remember primals +-- for the reverse pass. +-- +-- This combinator allows for "inactive" and "active" inputs to the operation; +-- derivatives to the "inactive" input are not propagated. The active input +-- (whose derivatives /are/ propagated) has type @b@; the inactive input has +-- type @a@. +-- +-- No accumulators are allowed inside @a@, @b@ and @tape@. +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + +-- | Semantically the identity, but when reverse differentiating using CHAD, +-- the contained expression is recomputed in the reverse pass. This is a +-- light-weight form of checkpointing, with the goal of reducing the number +-- primal values being stored and thus reducing memory use and memory traffic. +-- +-- Note that free variables of the contained expression do still need to be +-- stored, as we do need to be able to recompute the expression in the reverse +-- pass. +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + +-- | Introduce an accumulator. The initial value is not allowed to be sparse! +-- See 'CHAD.AST.EWith'. Not supported in the input program for reverse +-- differentiation. +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b + +-- | Accumulate to an accumulator. Not supported in the input program for +-- reverse differentiation. +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +-- | Accumulate to an accumulator with additional sparsity. Not supported in +-- the input program for reverse differentiation. +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c + + +(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .== b = oper (OEq knownScalTy) (pair a b) +infix 4 .== + +(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .< b = oper (OLt knownScalTy) (pair a b) +infix 4 .< + +(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>) = flip (.<) +infix 4 .> + +(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .<= b = oper (OLe knownScalTy) (pair a b) +infix 4 .<= + +(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>=) = flip (.<=) +infix 4 .>= + +not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) +not_ = oper ONot + +and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +and_ = oper2 OAnd +infixr 3 `and_` + +or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +or_ = oper2 OOr +infixr 2 `or_` + +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + +-- | The first alternative is the True case; the second is the False case. +if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) + +round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) +round_ = oper ORound64 + +toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) +toFloat_ = oper OToFl64 + +idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) +idiv = oper2 (OIDiv knownScalTy) +infixl 7 `idiv` diff --git a/src/Language/AST.hs b/src/CHAD/Language/AST.hs index be98ccf..502a2b3 100644 --- a/src/Language/AST.hs +++ b/src/CHAD/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -12,20 +14,22 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module Language.AST where +module CHAD.Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) -import Array -import AST -import AST.Sparse.Types -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +-- | A named expression: variables have names, not De Bruijn indices. +-- Otherwise essentially identical to 'Expr'. type NExpr :: [(Symbol, Ty)] -> Ty -> Type data NExpr env t where -- lambda calculus @@ -50,12 +54,23 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -86,11 +101,23 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +-- | Look up the type of a name in a named environment. +type Lookup name env = Lookup1 (name == "_") name env +-- | This curious stack of type families is used instead of normal pattern +-- matching so the decidable boolean predicate "==" is used. This means that +-- introducing evidence of @(name1 == name2) ~ False@ may allow a certain +-- lookup to reduce even if the names in question are not statically known. +-- This flexibility is used with e.g. 'assertSymbolDistinct' in +-- 'CHAD.Language.fold1i'. +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -142,10 +169,20 @@ data NEnv env where NTop :: NEnv '[] NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) --- | First (outermost) parameter on the outside, on the left. --- * env: environment of this function (grows as you go deeper inside lambdas) --- * env': environment of the body of the function --- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list +-- | A named /function/. These can be used in only two ways: they can be +-- converted to an unnamed 'Expr' using 'fromNamed', and they can be inlined +-- using 'CHAD.Language.inline'. +-- +-- * @env@: environment of this function (smaller than @env'@; grows as you descend under lambdas) +-- * @env'@: environment of the body of the function +-- +-- For example, a function @(\\(x :: a) (y :: b) -> _ :: c)@ with two free +-- variables, @u :: t1@ and @v :: t2@, would be represented with a value of the +-- following type: +-- +-- @ +-- NFun '['("v", t2), '("u", t1)] '['("y", b), '("x", a), '("v", t2), '("u", t1)] c +-- @ data NFun env env' t where NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t NBody :: NExpr env' t -> NFun env' env' t @@ -161,6 +198,41 @@ envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t inlineNFun fun args = NEUnnamed (fromNamed fun) args +-- | Convert a named function to an unnamed expression with free variables, +-- ready for consumption by the rest of this library. The function must be +-- closed (meaning that the function as a whole cannot have free variables), +-- and the arguments of the function are realised as free variables of the +-- resulting expression. Typical usage looks as follows: +-- +-- @ +-- {-# LANGUAGE OverloadedLabels #-} +-- import CHAD.Language +-- 'fromNamed' $ 'CHAD.Language.lambda' \@(TScal TF64) #x $ 'CHAD.Language.lambda' \@(TScal TI64) #i $ 'CHAD.Language.body' $ #x + 'CHAD.Language.toFloat_' #i +-- :: 'Ex' '[TScal TI64, TScal TF64] (TScal TF64) +-- @ +-- +-- The rest of the library generally considers expressions with free variables +-- to stand in for "functions", by considering the free variables as the +-- function's inputs. +-- +-- Note that while environments normally grow to the right (e.g. in type theory +-- notation), as they as type-level lists here, they grow to the /left/. This +-- is why the second (innermost) argument of the example, @i@, ends up at the +-- head of the environment of the constructed expression. +-- +-- __Type applications__: The type applications to 'CHAD.Language.lambda' above +-- are good practice, but not always necessary; if GHC can infer the type of +-- the argument from the body of the expression, the type application is +-- unnecessary. +-- +-- __Variables__: The major element of syntactic sugar in this module is using +-- OverloadedLabels for variable names. Variables are represented in 'NExpr' +-- (and thus 'NFun') using the 'Var' type; you should never have to manually +-- construct a 'Var'. Instead, 'Var' implements 'IsLabel' and as such can be +-- produced with the syntax @#name@, where "name" is the name of the variable. +-- This syntax produces a polymorphic variable reference whose (embedded) type +-- is left to GHC's type inference engine using a 'KnownTy' constraint. See +-- also 'CHAD.Language.let_'. fromNamed :: NFun '[] env t -> Ex (UnName env) t fromNamed = fromNamedFun NTop @@ -199,12 +271,17 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEMap n a b -> EMap ext (lambda val n a) (go b) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) NEMaximum1Inner e -> EMaximum1Inner ext (go e) NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) + + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) @@ -261,3 +338,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/Lemmas.hs b/src/CHAD/Lemmas.hs index 31a43ed..55ef042 100644 --- a/src/Lemmas.hs +++ b/src/CHAD/Lemmas.hs @@ -4,7 +4,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE AllowAmbiguousTypes #-} -module Lemmas (module Lemmas, (:~:)(Refl)) where +module CHAD.Lemmas (module CHAD.Lemmas, (:~:)(Refl)) where import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) diff --git a/src/Simplify.hs b/src/CHAD/Simplify.hs index 74b6601..ea253d6 100644 --- a/src/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -12,7 +12,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Simplify ( +module CHAD.Simplify ( simplifyN, simplifyFix, SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, ) where @@ -24,13 +24,13 @@ import Data.Monoid (Any(..)) import Debug.Trace -import AST -import AST.Count -import AST.Pretty -import AST.Sparse.Types -import AST.UnMonoid (acPrjCompose) -import Data -import Simplify.TH +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.AST.UnMonoid (acPrjCompose) +import CHAD.Data +import CHAD.Simplify.TH data SimplifyConfig = SimplifyConfig @@ -185,6 +185,21 @@ simplify'Rec = \case ELet _ e1 (ENil _) | STNil <- typeOf e1 -> acted $ simplify' e1 + -- map (\_ -> x) e ~> build (shape e) (\_ -> x) + EMap _ e1 e2 + | Occ Zero Zero <- occCount IZ e1 + , STArr n _ <- typeOf e2 -> + acted $ simplify' $ + EBuild ext n (EShape ext e2) $ + subst (\_ t' -> \case IZ -> error "Unused variable was used" + IS i -> EVar ext t' (IS i)) + e1 + + -- vertical fusion + EMap _ e1 (EMap _ e2 e3) -> + acted $ simplify' $ + EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 + -- projection down-commuting EFst _ (ECase _ e1 e2 e3) -> acted $ simplify' $ @@ -200,11 +215,15 @@ simplify'Rec = \case EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 -- TODO: more array indexing - EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) - EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 + EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 + EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) + EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 -- TODO: more array shape - EShape _ (EBuild _ _ e _) -> acted $ simplify' e + EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 + EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) + EShape _ (EReplicate1Inner _ en earr) -> acted $ simplify' (EPair ext (EShape ext earr) en) -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) @@ -307,12 +326,17 @@ simplify'Rec = \case ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] EConstArr _ n t v -> pure $ EConstArr ext n t v EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EMap _ a b -> [simprec| EMap ext *a *b |] EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] EUnit _ e -> [simprec| EUnit ext *e |] EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EReshape _ n a b -> [simprec| EReshape ext n *a *b |] + EZip _ a b -> [simprec| EZip ext *a *b |] + EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] + EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> [simprec| EIdx0 ext *e |] EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] @@ -331,21 +355,13 @@ simplify'Rec = \case e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) pure (EWith ext t e1' e2') + -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] + -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] EZero _ t e -> [simprec| EZero ext t *e |] EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - EUnit _ e -> cheapExpr e - _ -> False - -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. hasAdds :: Expr x env t -> Bool @@ -368,12 +384,17 @@ hasAdds = \case ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b + EMap _ a b -> hasAdds a || hasAdds b EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e + EReshape _ _ a b -> hasAdds a || hasAdds b + EZip _ a b -> hasAdds a || hasAdds b + EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e diff --git a/src/Simplify/TH.hs b/src/CHAD/Simplify/TH.hs index 2e0076a..4af5394 100644 --- a/src/Simplify/TH.hs +++ b/src/CHAD/Simplify/TH.hs @@ -1,9 +1,9 @@ {-# LANGUAGE TemplateHaskellQuotes #-} -module Simplify.TH (simprec) where +module CHAD.Simplify.TH (simprec) where import Data.Bifunctor (first) import Data.Char -import Data.List (foldl1') +import Data.List (foldl', foldl1') import Language.Haskell.TH import Language.Haskell.TH.Quote import Text.ParserCombinators.ReadP diff --git a/src/Util/IdGen.hs b/src/CHAD/Util/IdGen.hs index 3f6611d..d4fd945 100644 --- a/src/Util/IdGen.hs +++ b/src/CHAD/Util/IdGen.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -module Util.IdGen where +module CHAD.Util.IdGen where import Control.Monad.Fix import Control.Monad.Trans.State.Strict diff --git a/src/Language.hs b/src/Language.hs deleted file mode 100644 index 4e6d604..0000000 --- a/src/Language.hs +++ /dev/null @@ -1,233 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitForAll #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} -module Language ( - fromNamed, - NExpr, - Ex, - module Language, - module AST.Types, - module Data, - Lookup, -) where - -import Array -import AST -import AST.Sparse.Types -import AST.Types -import CHAD.Types -import Data -import Language.AST - - -data a :-> b = a :-> b - deriving (Show) -infixr 0 :-> - - -body :: NExpr env t -> NFun env env t -body = NBody - -lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t -lambda = NLam - -inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t -inline = inlineNFun - --- To be used to construct the argument list for 'inline'. --- --- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y --- > in inline fun (SNil .$ 16 .$ 26) -(.$) :: SList f list -> f a -> SList f (a : list) -(.$) = flip SCons - - -let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t -let_ = NELet - -pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) -pair = NEPair - -fst_ :: NExpr env (TPair a b) -> NExpr env a -fst_ = NEFst - -snd_ :: NExpr env (TPair a b) -> NExpr env b -snd_ = NESnd - -nil :: NExpr env TNil -nil = NENil - -inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) -inl = NEInl knownTy - -inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) -inr = NEInr knownTy - -case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c -case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 - -nothing :: KnownTy a => NExpr env (TMaybe a) -nothing = NENothing knownTy - -just :: NExpr env a -> NExpr env (TMaybe a) -just = NEJust - -maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b -maybe_ a (v :-> b) c = NEMaybe a v b c - -constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) -constArr_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConstArr knownNat ty x - -build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) -build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) - -build2 :: NExpr env TIx -> NExpr env TIx - -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) - -> NExpr env (TArr (S (S Z)) t) -build2 a1 a2 (v1 :-> v2 :-> b) = - NEBuild (SS (SS SZ)) - (pair (pair nil a1) a2) - #idx - (let_ v1 (snd_ (fst_ #idx)) $ - let_ v2 (NEDrop SZ (snd_ #idx)) $ - NEDrop (SS (SS SZ)) b) - -build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) -build n a (v :-> b) = NEBuild n a v b - -map_ :: forall n a b env name. (KnownNat n, KnownTy a) - => (Var name a :-> NExpr ('(name, a) : env) b) - -> NExpr env (TArr n a) -> NExpr env (TArr n b) -map_ (v :-> a) b - | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) = - let_ #arg b $ - build knownNat (shape #arg) $ #i :-> - let_ v (#arg ! #i) $ - NEDrop (SS SZ) (NEDrop (SS SZ) a) - -fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 - -sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -sum1i e = NESum1Inner e - -unit :: NExpr env t -> NExpr env (TArr Z t) -unit = NEUnit - -replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) -replicate1i n a = NEReplicate1Inner n a - -maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -maximum1i e = NEMaximum1Inner e - -minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -minimum1i e = NEMinimum1Inner e - -const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) -const_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty x - -idx0 :: NExpr env (TArr Z t) -> NExpr env t -idx0 = NEIdx0 - --- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) --- (.!) = NEIdx1 --- infixl 9 .! - -(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t -(!) = NEIdx -infixl 9 ! - -shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -shape = NEShape - -length_ :: NExpr env (TArr N1 t) -> NExpr env TIx -length_ e = snd_ (shape e) - -oper :: SOp a t -> NExpr env a -> NExpr env t -oper = NEOp - -oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t -oper2 op a b = NEOp op (pair a b) - -error_ :: KnownTy t => String -> NExpr env t -error_ s = NEError knownTy s - -custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) - -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) - -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) - -> NExpr env a -> NExpr env b - -> NExpr env t -custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = - NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 - -recompute :: NExpr env a -> NExpr env a -recompute = NERecompute - -with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) -with a (n :-> b) = NEWith (knownMTy @t) a n b - -accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c - -accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -accumS p a sp b c = NEAccum knownMTy p a sp b c - - -(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .== b = oper (OEq knownScalTy) (pair a b) -infix 4 .== - -(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .< b = oper (OLt knownScalTy) (pair a b) -infix 4 .< - -(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>) = flip (.<) -infix 4 .> - -(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .<= b = oper (OLe knownScalTy) (pair a b) -infix 4 .<= - -(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>=) = flip (.<=) -infix 4 .>= - -not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -not_ = oper ONot - -and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -and_ = oper2 OAnd -infixr 3 `and_` - -or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -or_ = oper2 OOr -infixr 2 `or_` - -mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) -mod_ = oper2 (OMod knownScalTy) -infixl 7 `mod_` - --- | The first alternative is the True case; the second is the False case. -if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t -if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) - -round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) -round_ = oper ORound64 - -toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) -toFloat_ = oper OToFl64 - -idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) -idiv = oper2 (OIDiv knownScalTy) -infixl 7 `idiv` diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index 1b2b7d7..5ca0f38 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -1,63 +1,109 @@ +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module Test.Framework ( TestTree, testGroup, - testGroupCollapse, + groupSetCollapse, testProperty, - withResource, - withResource', runTests, defaultMain, Options(..), + -- * Resources + withResource, + withResource', + TestCtx, + outputWarningText, + -- * Compatibility TestName, ) where -import Control.Monad (forM, when) -import Control.Monad.Trans.State.Strict +import Control.Concurrent (setNumCapabilities, forkIO, forkOn, killThread) +import Control.Concurrent.MVar +import Control.Concurrent.STM +import Control.Exception (SomeException, throw, try, throwIO) +import Control.Monad (forM, when, forM_) import Control.Monad.IO.Class +import Data.IORef import Data.List (isInfixOf, intercalate) -import Data.Maybe (isJust, mapMaybe, fromJust) +import Data.Maybe (mapMaybe, fromJust) +import Data.Monoid (All(..), Any(..), Sum(..)) +import Data.PQueue.Prio.Min qualified as PQ import Data.String (fromString) import Data.Time.Clock -import System.Environment +import GHC.Conc (getNumProcessors) +import GHC.Generics (Generic, Generically(..)) +import System.Console.ANSI qualified as ANSI +import System.Console.Concurrent (outputConcurrent) +import System.Console.Regions +import System.Environment (getArgs, getProgName) import System.Exit import System.IO (hFlush, hPutStrLn, stdout, stderr, hIsTerminalDevice) import Text.Read (readMaybe) -import qualified Hedgehog as H -import qualified Hedgehog.Internal.Config as H -import qualified Hedgehog.Internal.Property as H -import qualified Hedgehog.Internal.Report as H -import qualified Hedgehog.Internal.Runner as H -import qualified Hedgehog.Internal.Seed as H.Seed +import Hedgehog qualified as H +import Hedgehog.Internal.Config qualified as H +import Hedgehog.Internal.Property qualified as H +import Hedgehog.Internal.Report qualified as H +import Hedgehog.Internal.Runner qualified as H +import Hedgehog.Internal.Seed qualified as H.Seed + +-- TODO: with GHC 9.12 we have tryWithContext and rethrowIO, which is better for rethrowing exceptions + + +type TestName = String data TestTree - = Group Bool String [TestTree] - | forall a. Resource (IO a) (a -> IO ()) (a -> TestTree) + = Group GroupOpts String [TestTree] + | forall a. Resource String ((?testCtx :: TestCtx) => IO a) ((?testCtx :: TestCtx) => a -> IO ()) (a -> TestTree) + -- ^ Name is not specified by user, but inherited from the tree below | HP String H.Property -type TestName = String +data TestCtx = TestCtx + { tctxOutput :: String -> IO () } + +outputWarningText :: (?testCtx :: TestCtx) => String -> IO () +outputWarningText = tctxOutput ?testCtx + +-- Not exported because a Resource is not supposed to have a name in the first place +treeName :: TestTree -> String +treeName (Group _ name _) = name +treeName (Resource name _ _ _) = name +treeName (HP name _) = name + +data GroupOpts = GroupOpts + { goCollapse :: Bool } + deriving (Show) + +defaultGroupOpts :: GroupOpts +defaultGroupOpts = GroupOpts False testGroup :: String -> [TestTree] -> TestTree -testGroup = Group False +testGroup = Group defaultGroupOpts -testGroupCollapse :: String -> [TestTree] -> TestTree -testGroupCollapse = Group True +groupSetCollapse :: TestTree -> TestTree +groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name trees +groupSetCollapse _ = error "groupSetCollapse: not called on a Group" --- | The @a -> TestTree@ function must use the @a@ only inside properties: when --- not actually running properties, it will be passed 'undefined'. -withResource :: IO a -> (a -> IO ()) -> (a -> TestTree) -> TestTree -withResource = Resource +-- | The @a -> TestTree@ function must use the @a@ only inside properties: the +-- function will be passed 'undefined' when exploring the test tree (without +-- running properties). +withResource :: ((?testCtx :: TestCtx) => IO a) -> ((?testCtx :: TestCtx) => a -> IO ()) -> (a -> TestTree) -> TestTree +withResource make cleanup fun = Resource (treeName (fun undefined)) make cleanup fun -- | Same caveats as 'withResource'. -withResource' :: IO a -> (a -> TestTree) -> TestTree +withResource' :: ((?testCtx :: TestCtx) => IO a) -> (a -> TestTree) -> TestTree withResource' make fun = withResource make (\_ -> return ()) fun testProperty :: String -> H.Property -> TestTree @@ -66,27 +112,35 @@ testProperty = HP filterTree :: Options -> TestTree -> Maybe TestTree filterTree (Options { optsPattern = pat }) = go [] where - go path (Group collapse name trees) = + go path (Group opts name trees) = case mapMaybe (go (name:path)) trees of [] -> Nothing - trees' -> Just (Group collapse name trees') - go path (Resource make free fun) = + trees' -> Just (Group opts name trees') + go path (Resource inhname make free fun) = case go path (fun undefined) of Nothing -> Nothing - Just _ -> Just $ Resource make free (fromJust . go path . fun) + Just _ -> Just $ Resource inhname make free (fromJust . go path . fun) go path hp@(HP name _) | pat `isInfixOf` renderPath (name:path) = Just hp | otherwise = Nothing renderPath comps = "^" ++ intercalate "/" (reverse comps) ++ "$" +treeNumTests :: TestTree -> Int +treeNumTests (Group _ _ ts) = sum (map treeNumTests ts) +treeNumTests (Resource _ _ _ fun) = treeNumTests (fun undefined) +treeNumTests HP{} = 1 + computeMaxLen :: TestTree -> Int computeMaxLen = go 0 where go :: Int -> TestTree -> Int - go indent (Group True name trees) = maximum (2*indent + length name : map (go (indent+1)) trees) - go indent (Group False _ trees) = maximum (0 : map (go (indent+1)) trees) - go indent (Resource _ _ fun) = go indent (fun undefined) + go indent (Group opts name trees) + -- If we collapse, the name of the group gets prefixed before the final status message after collapsing. + | goCollapse opts = maximum (2*indent + length name : map (go (indent+1)) trees) + -- If we don't collapse, the group name does get printed but without any status message, so it doesn't need to get accounted for in maxlen. + | otherwise = maximum (0 : map (go (indent+1)) trees) + go indent (Resource _ _ _ fun) = go indent (fun undefined) go indent (HP name _) = 2 * indent + length name data Stats = Stats @@ -97,22 +151,21 @@ data Stats = Stats initStats :: Stats initStats = Stats 0 0 -newtype M a = M (StateT Stats IO a) - deriving newtype (Functor, Applicative, Monad, MonadIO) - -modifyStats :: (Stats -> Stats) -> M () -modifyStats f = M (modify f) +modifyStats :: (?stats :: IORef Stats) => (Stats -> Stats) -> IO () +modifyStats f = atomicModifyIORef' ?stats (\s -> (f s, ())) data Options = Options { optsPattern :: String , optsHelp :: Bool , optsHedgehogReplay :: Maybe (H.Skip, H.Seed) , optsHedgehogShrinks :: Maybe Int + , optsParallel :: Bool + , optsVerbose :: Bool } deriving (Show) defaultOptions :: Options -defaultOptions = Options "" False Nothing Nothing +defaultOptions = Options "" False Nothing Nothing False False parseOptions :: [String] -> Options -> Either String Options parseOptions [] opts = pure opts @@ -134,6 +187,8 @@ parseOptions ("--hedgehog-shrinks":arg:args) opts = case readMaybe arg of Just n -> parseOptions args opts { optsHedgehogShrinks = Just n } Nothing -> Left "Invalid argument to '--hedgehog-shrinks'" +parseOptions ("--parallel":args) opts = parseOptions args opts { optsParallel = True } +parseOptions ("--verbose":args) opts = parseOptions args opts { optsVerbose = True } parseOptions (arg:_) _ = Left $ "Unrecognised argument: '" ++ arg ++ "'" printUsage :: IO () @@ -147,7 +202,12 @@ printUsage = do ," test looks like: '^group1/group2/testname$'." ," --hedgehog-replay '{skip} {seed}'" ," Skip to a particular generated Hedgehog test. Should be used" - ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set."] + ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set." + ," --hedgehog-shrinks NUM" + ," Limit the number of shrinking steps." + ," --parallel Run tests in parallel." + ," --verbose Currently only has an effect with --parallel. Also shows OK" + ," and timing for test groups, not only individual tests."] defaultMain :: TestTree -> IO () defaultMain tree = do @@ -165,58 +225,210 @@ runTests options = \tree' -> return (ExitFailure 1) Just tree -> do isterm <- hIsTerminalDevice stdout - let M m = let ?maxlen = computeMaxLen tree - ?istty = isterm - in go 0 id tree starttm <- getCurrentTime - (success, stats) <- runStateT m initStats + statsRef <- newIORef initStats + success <- + let ?stats = statsRef + ?options = options + ?maxlen = computeMaxLen tree + ?istty = isterm + in if optsParallel options + then do nproc <- getNumProcessors + setNumCapabilities nproc + displayConsoleRegions $ + withWorkerPool nproc $ \pool -> do + let ?pool = pool + successVar <- newEmptyMVar + runTreePar Nothing [] [] tree successVar + readMVar successVar + else getAll . seqresAllSuccess <$> runTreeSeq 0 [] tree + stats <- readIORef statsRef endtm <- getCurrentTime - let ?istty = isterm in printStats stats (diffUTCTime endtm starttm) - return (if isJust success then ExitSuccess else ExitFailure 1) + let ?istty = isterm in printStats (treeNumTests tree) stats (diffUTCTime endtm starttm) + return (if success then ExitSuccess else ExitFailure 1) + +-- | Returns when all jobs in this tree have been scheduled. When all jobs are +-- done, the outvar is filled with whether all tests in this tree were +-- successful. +-- The mparregion is the parent region to take over, if any. Having a parent +-- region to take over implies that we are currently executing in a worker and +-- can hence run blocking user code directly in the current thread. +runTreePar :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> [Int] -> [String] -> TestTree -> MVar Bool -> IO () +-- TODO: handle collapse somehow? +runTreePar mparregion revidxlist revpath (Group _ name trees) outvar = do + let path = intercalate "/" (reverse (name : revpath)) + -- outputConcurrent $ "! " ++ path ++ ": Started\n" + + -- If we have exactly one child and we're currently running in a worker, we can continue doing so + mparregion2 <- case trees of + [_] -> return mparregion + _ -> do -- If not, we have to close the region (and implicitly relinquish the worker job) + forM_ mparregion closeConsoleRegion + return Nothing + + starttm <- getCurrentTime + suboutvars <- forM (zip trees [0..]) $ \(tree, idx) -> do + suboutvar <- newEmptyMVar + runTreePar mparregion2 (idx : revidxlist) (name : revpath) tree suboutvar + return suboutvar + + -- outputConcurrent $ "! " ++ path ++ ": Scheduled all\n" + + -- If we took over the parent region then this readMVar will resolve + -- immediately and the forkIO would be unnecessary. Meh. + _ <- forkIO $ do + success <- and <$> traverse readMVar suboutvars + endtm <- getCurrentTime + -- outputConcurrent $ "! " ++ path ++ ": Done\n" + if success && optsVerbose ?options + then let text = path ++ ": " ++ ansiGreen ++ "OK" ++ ansiReset ++ " " ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + in outputConcurrent (text ++ "\n") + else return () + putMVar outvar success + + return () + +runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runResource topmparregion 1 toptree topoutvar where - -- If all tests are successful, returns the number of output lines produced - go :: (?maxlen :: Int, ?istty :: Bool) => Int -> (String -> String) -> TestTree -> M (Maybe Int) - go indent path (Group collapse name trees) = do - liftIO $ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout - starttm <- liftIO getCurrentTime - mlns <- fmap (fmap sum . sequence) . forM trees $ - go (indent + 1) (path . (name++) . ('/':)) - endtm <- liftIO getCurrentTime - case mlns of - Just lns | collapse, ?istty -> do - let thislen = 2*indent + length name - liftIO $ putStrLn $ concat (replicate (lns+1) "\x1B[A\x1B[2K") ++ "\x1B[G" ++ - replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ - "\x1B[32mOK\x1B[0m" ++ - prettyDuration False (realToFrac (diffUTCTime endtm starttm)) - return (Just 1) - _ -> return ((+1) <$> mlns) - go indent path (Resource make cleanup fun) = do - value <- liftIO make - success <- go indent path (fun value) - liftIO $ cleanup value - return success - go indent path (HP name (H.Property config test)) = do + runResource + :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool, ?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> Int -> TestTree -> MVar Bool -> IO () + runResource mparregion depth (Resource inhname make cleanup fun) outvar = do + let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")" + path = intercalate "/" (reverse (pathitem : revpath)) + idxlist = reverse revidxlist + let ?testCtx = TestCtx (\str -> + outputConcurrent (ansiYellow ++ "## Warning for " ++ path ++ ":" ++ ansiReset ++ + "\n" ++ str ++ "\n")) + submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do + setConsoleRegion makeRegion ('|' : path ++ " [R] making...") + + try make >>= \case + Left (err :: SomeException) -> do + finishConsoleRegion makeRegion $ + ansiRed ++ "Exception building resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err + putMVar outvar False + Right value -> do + suboutvar <- newEmptyMVar + runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion + + _ <- forkIO $ do + success <- readMVar suboutvar + poolSubmit ?pool idxlist Nothing $ do + cleanupRegion <- openConsoleRegion Linear + setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...") + try (cleanup value) >>= \case + Left (err :: SomeException) -> do + finishConsoleRegion cleanupRegion $ + ansiRed ++ "Exception cleaning up resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err + putMVar outvar False + Right () -> do + closeConsoleRegion cleanupRegion + putMVar outvar success + return () + runResource mparregion _ tree outvar = runTreePar mparregion revidxlist revpath tree outvar + +runTreePar mparregion revidxlist revpath (HP name prop) outvar = do + let path = intercalate "/" (reverse (name : revpath)) + idxlist = reverse revidxlist + + submitOrRunIn mparregion idxlist (Just outvar) $ \region -> do + -- outputConcurrent $ "! " ++ path ++ ": Running" + let prefix = path ++ " [T]" + setConsoleRegion region ('|' : prefix) + let progressHandler report = do + str <- H.renderProgress H.EnableColor (Just (fromString "")) report + setConsoleRegion region ('|' : prefix ++ ": " ++ replace '\n' " " str) + (ok, rendered) <- runHP progressHandler revpath name prop + finishConsoleRegion region (path ++ ": " ++ rendered) + return ok + +submitOrRunIn :: (?pool :: WorkerPool [Int]) + => Maybe ConsoleRegion -> [Int] -> Maybe (MVar a) -> (ConsoleRegion -> IO a) -> IO () +submitOrRunIn Nothing idxlist outvar fun = + poolSubmit ?pool idxlist outvar (openConsoleRegion Linear >>= fun) +submitOrRunIn (Just reg) _idxlist outvar fun = do + result <- fun reg + forM_ outvar $ \mvar -> putMVar mvar result + +data SeqRes = SeqRes + { seqresHaveWarnings :: Any + , seqresAllSuccess :: All + , seqresNumLines :: Sum Int } + deriving (Generic) + deriving (Semigroup, Monoid) via Generically SeqRes + +-- | If all tests are successful, returns the number of output lines produced +runTreeSeq :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool) + => Int -> [String] -> TestTree -> IO SeqRes +runTreeSeq indent revpath (Group opts name trees) = do + putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout + starttm <- getCurrentTime + res <- fmap mconcat . forM trees $ + runTreeSeq (indent + 1) (name : revpath) + endtm <- getCurrentTime + if not (getAny (seqresHaveWarnings res)) && getAll (seqresAllSuccess res) && goCollapse opts && ?istty + then do let thislen = 2*indent + length name - let outputPrefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' - when ?istty $ liftIO $ putStr outputPrefix >> hFlush stdout + let Sum lns = seqresNumLines res + putStrLn $ concat (replicate (lns+1) (ANSI.cursorUpCode 1 ++ ANSI.clearLineCode)) ++ + ANSI.setCursorColumnCode 0 ++ + replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ + ansiGreen ++ "OK" ++ ansiReset ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + return (mempty { seqresNumLines = 1 }) + else return (res <> mempty { seqresNumLines = 1 }) +runTreeSeq indent revpath (Resource _ make cleanup fun) = do + let path = intercalate "/" (reverse revpath) + outputted <- newIORef False + let ?testCtx = TestCtx (\str -> do + atomicModifyIORef' outputted (\_ -> (True, ())) + putStrLn (ansiYellow ++ "## Warning for " ++ path ++ + ":" ++ ansiReset ++ "\n" ++ str)) + res <- try make >>= \case + Left (err :: SomeException) -> do + putStrLn $ ansiRed ++ "Exception building resource at " ++ path ++ ":" ++ ansiReset + print err + return (mempty { seqresAllSuccess = All False }) + Right value -> do + res <- runTreeSeq indent revpath (fun value) + try (cleanup value) >>= \case + Left (err :: SomeException) -> do + putStrLn $ ansiRed ++ "Exception cleaning up resource at " ++ path ++ ":" ++ ansiReset + print err + return (res { seqresAllSuccess = All False }) + Right () -> return res - let (config', seedfun) = applyHedgehogOptions options config - seed <- seedfun + warnings <- readIORef outputted + return (res <> mempty { seqresHaveWarnings = Any warnings }) +runTreeSeq indent path (HP name prop) = do + let thislen = 2*indent + length name + let prefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' + when ?istty $ putStr prefix >> hFlush stdout + (ok, rendered) <- runHP (outputProgress (?maxlen + 2)) path name prop + putStrLn ((if ?istty then ANSI.clearFromCursorToLineEndCode else prefix) ++ rendered) >> hFlush stdout + return (mempty { seqresAllSuccess = All ok, seqresNumLines = 1 }) - starttm <- liftIO getCurrentTime - report <- liftIO $ H.checkReport config' 0 seed test (outputProgress (?maxlen + 2)) - endtm <- liftIO getCurrentTime +runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int) + => (H.Report H.Progress -> IO ()) + -> [String] + -> String -> H.Property -> IO (Bool, String) +runHP progressPrinter revpath name (H.Property config test) = do + let (config', seedfun) = applyHedgehogOptions ?options config + seed <- seedfun - liftIO $ do - when (not ?istty) $ putStr outputPrefix - printResult report (path name) (diffUTCTime endtm starttm) - hFlush stdout + starttm <- getCurrentTime + report <- H.checkReport config' 0 seed test progressPrinter + endtm <- getCurrentTime - let ok = H.reportStatus report == H.OK - modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats - , statsTotal = 1 + statsTotal stats } - return (if ok then Just 1 else Nothing) + rendered <- renderResult report (intercalate "/" (reverse (name : revpath))) (diffUTCTime endtm starttm) + + let ok = H.reportStatus report == H.OK + modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats + , statsTotal = 1 + statsTotal stats } + return (ok, rendered) applyHedgehogOptions :: MonadIO m => Options -> H.PropertyConfig -> (H.PropertyConfig, m H.Seed) applyHedgehogOptions opts config0 = @@ -232,32 +444,77 @@ outputProgress :: (?istty :: Bool) => Int -> H.Report H.Progress -> IO () outputProgress indent report | ?istty = do str <- H.renderProgress H.EnableColor (Just (fromString "")) report - putStr (replace '\n' " " str ++ "\x1B[" ++ show (indent+1) ++ "G") + putStr (replace '\n' " " str ++ ANSI.setCursorColumnCode indent) hFlush stdout | otherwise = return () -printResult :: (?istty :: Bool) => H.Report H.Result -> String -> NominalDiffTime -> IO () -printResult report path timeTaken = do +renderResult :: H.Report H.Result -> String -> NominalDiffTime -> IO String +renderResult report path timeTaken = do str <- H.renderResult H.EnableColor (Just (fromString "")) report case H.reportStatus report of - H.OK -> putStrLn (ansi "\x1B[K" ++ str ++ prettyDuration False (realToFrac timeTaken)) + H.OK -> return (str ++ prettyDuration False (realToFrac timeTaken)) H.Failed failure -> do let H.Report { H.reportTests = count, H.reportDiscards = discards } = report replayInfo = H.skipCompress (H.SkipToShrink count discards (H.failureShrinkPath failure)) ++ " " ++ show (H.reportSeed report) suffix = "\n Flags to reproduce: `-p '" ++ path ++ "' --hedgehog-replay '" ++ replayInfo ++ "'`" - putStrLn (ansi "\x1B[K" ++ str ++ suffix) - _ -> putStrLn (ansi "\x1B[K" ++ str) + return (str ++ suffix) + _ -> return str -printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO () -printStats stats timeTaken - | statsOK stats == statsTotal stats = do - putStrLn $ ansi "\x1B[32m" ++ "All " ++ show (statsTotal stats) ++ - " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" +printStats :: (?istty :: Bool) => Int -> Stats -> NominalDiffTime -> IO () +printStats numTests stats timeTaken + | statsOK stats == numTests = do + putStrLn $ ansiGreen ++ "All " ++ show (statsTotal stats) ++ + " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + | statsOK stats == statsTotal stats = + putStrLn $ ansiRed ++ "Failed (" ++ show (numTests - statsTotal stats) ++ " tests could not run)." ++ + prettyDuration True (realToFrac timeTaken) ++ ansiReset | otherwise = let nfailed = statsTotal stats - statsOK stats - in putStrLn $ ansi "\x1B[31m" ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ - " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" + in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ " tests" ++ + (if statsTotal stats /= numTests then " (" ++ show (numTests - statsTotal stats) ++ " could not run)" else "") ++ + "." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + + +newtype WorkerPool k = WorkerPool (TVar (PQ.MinPQueue k (Terminate PoolJob))) +data PoolJob = forall a. PoolJob (IO a) (Maybe (MVar a)) +data Terminate a = Terminate | Value a + deriving (Eq, Ord) -- Terminate sorts before Value + +withWorkerPool :: Ord k => Int -> (WorkerPool k -> IO a) -> IO a +withWorkerPool numWorkers k = do + chan <- newTVarIO PQ.empty + threads <- forM [0..numWorkers-1] (\i -> forkOn i (worker i chan)) + eres <- try (k (WorkerPool chan)) + case eres of + Left (err :: SomeException) -> do + atomically $ writeTVar chan PQ.empty + forM_ threads killThread + throw err + Right res -> do + readTVarIO chan >>= \case + PQ.Empty -> return () + _ -> throwIO (userError "withWorkerPool: computation exited before all jobs were handled") + return res + where + worker :: Ord k => Int -> TVar (PQ.MinPQueue k (Terminate PoolJob)) -> IO () + worker idx chan = do + job <- atomically $ + readTVar chan >>= \case + PQ.Empty -> retry + (_, j) PQ.:< q -> writeTVar chan q >> return j + case job of + Value (PoolJob action mmvar) -> do + -- outputConcurrent $ "[" ++ show idx ++ "] got job\n" + result <- action + forM_ mmvar $ \mvar -> putMVar mvar result + worker idx chan + Terminate -> return () + +poolSubmit :: Ord k => WorkerPool k -> k -> Maybe (MVar a) -> IO a -> IO () +poolSubmit (WorkerPool chan) key mmvar action = + atomically $ modifyTVar chan $ PQ.insert key (Value (PoolJob action mmvar)) + prettyDuration :: Bool -> Double -> String prettyDuration False x | x < 0.5 = "" @@ -273,3 +530,24 @@ replace x ys = concatMap (\y -> if y == x then ys else [y]) ansi :: (?istty :: Bool) => String -> String ansi | ?istty = id | otherwise = const "" + +ansiRed, ansiYellow, ansiGreen, ansiReset :: (?istty :: Bool) => String +ansiRed = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red]) +ansiYellow = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Vivid ANSI.Yellow]) +ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green]) +ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset]) + +-- getTermIsDark :: IO Bool +-- getTermIsDark = do +-- mclr <- ANSI.getLayerColor ANSI.Background +-- case mclr of +-- Nothing -> return True +-- Just (RGB r g b) -> +-- let cvt n = fromIntegral n / fromIntegral (maxBound `asTypeOf` n) +-- in return $ (cvt r + cvt g + cvt b) / 3 < (0.5 :: Double) + +-- ansiRegionBg :: (?istty :: Bool, ?termisdark :: Bool) => String +-- ansiRegionBg +-- | not ?istty = "" +-- | ?termisdark = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.0 0.05 0.1)] +-- | otherwise = ANSI.setSGRCode [ANSI.SetRGBColor ANSI.Background (rgb 0.95 0.95 1.0)] diff --git a/test/Main.hs b/test/Main.hs index 0a57cbf..05597cc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} @@ -11,35 +12,38 @@ {-# LANGUAGE UndecidableInstances #-} module Main where +import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State import Data.Bifunctor import Data.Int (Int64) import Data.Map.Strict (Map) -import qualified Data.Map.Strict as Map -import qualified Data.Text as T +import Data.Map.Strict qualified as Map +import Data.Text qualified as T import Hedgehog -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range import Test.Framework -import Array -import AST hiding ((.>)) -import AST.Pretty -import AST.UnMonoid -import CHAD.Top -import CHAD.Types -import CHAD.Types.ToTan -import Compile -import qualified Example -import qualified Example.GMM as Example -import Example.Types -import ForwardAD -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep -import Language -import Simplify +import CHAD.Array +import CHAD.AST hiding ((.>)) +import CHAD.AST.Count (pruneExpr) +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Drev.Types.ToTan +import CHAD.Example qualified as Example +import CHAD.Example.GMM qualified as Example +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep +import CHAD.Language +import CHAD.Simplify data TypedValue t = TypedValue (STy t) (Rep t) @@ -63,18 +67,18 @@ simplifyIters iters env | Dict <- envKnown env = SimplIters n -> simplifyN n SimplFix -> simplifyFix --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD simplIters env term input = - let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False env input dterm - in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) +-- gradientByCHAD simplIters env term input = +-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term +-- (out, grad) = interpretOpen False env input dterm +-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' simplIters env term input = - second (second (toTanE env input)) $ - gradientByCHAD simplIters env term input +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) +-- gradientByCHAD' simplIters env term input = +-- second (second (toTanE env input)) $ +-- gradientByCHAD simplIters env term input gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 @@ -92,8 +96,8 @@ extendDN (STMaybe _) Nothing = pure Nothing extendDN (STMaybe t) (Just x) = Just <$> extendDN t x extendDN (STArr _ t) arr = traverse (extendDN t) arr extendDN (STScal sty) x = case sty of - STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) - STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) + STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) STI32 -> pure x STI64 -> pure x STBool -> pure x @@ -217,8 +221,8 @@ genShape = \n tpl -> do shapeDiv :: Shape n -> DimNames n -> Int -> Shape n shapeDiv ShNil _ _ = ShNil - shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) - shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f)) + shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` max lo (n `div` f) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` max lo (n `div` f) shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f) shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f) @@ -240,8 +244,8 @@ genValue topty tpl = case topty of ,liftV Just <$> genValue t (emptyTpl t)] STArr n t -> genShape n tpl >>= lift . genArray t STScal sty -> case sty of - STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10) + STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10) STI32 -> genInt STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] @@ -302,7 +306,7 @@ adTestGen name expr envGenerator = exprS = simplifyFix expr in withCompiled env expr $ \primalfun -> withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name + groupSetCollapse $ testGroup name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS ,testGroup "chad" @@ -349,17 +353,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS + dtermSChadSUSP = simplifyFix $ pruneExpr env dtermSChadSUS in - withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env dtermSChadSUS $ \dcompSChadSUS -> + withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS + when (not (null output)) $ + outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>" + return fun) $ \fwdartifactC -> + withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) - diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0))) + let dtermChad20 = simplifyN 20 dtermChad0 + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20)) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20))) + let dtermSChad20 = simplifyN 20 dtermSChad0 + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20)) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -369,46 +379,54 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS - (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS - tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP - (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input) - let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS + (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input) + let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outChadSUS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outSChadSUS closeIsh outPrimal - diff outCompSChadSUS closeIsh outPrimal + annotate (ppExpr env dtermSChadSUSP) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outSChadSUSP closeIsh outPrimal + diff outCompSChadSUSP closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansChadSUS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansSChadSUS closeIshE' tansFwd - diff tansCompSChadSUS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansSChadSUSP closeIshE' tansFwd + diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +withCompiled env expr = withResource' $ do + (fun, output) <- compile env expr + when (not (null output)) $ + outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]) gen_gmm = do @@ -423,7 +441,7 @@ gen_gmm = do vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) - vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vgamma <- Gen.realFloat (Range.constant (-10) 10) vm <- Gen.integral (Range.linear 0 5) let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) k2 = 0.5 * vgamma * vgamma @@ -554,6 +572,13 @@ tests_Compile = testGroup "Compile" nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree @@ -659,6 +684,38 @@ tests_AD = testGroup "AD" ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm + + ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree + + ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ + idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + + ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #innern (snd_ #sh) $ + let_ #n (#innern * snd_ (fst_ #sh)) $ + let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ + -- ensure the input array to EReshape is shared + idx0 $ sum1i $ + build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) + + ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a + + ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a + + ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + let_ #v 2 $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a + + ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a + + ,adTest "map" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ sum1i $ map_ (#x :-> 2 * #x) #a ] main :: IO () |
