diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
| -rw-r--r-- | bench/Main.hs | 28 | ||||
| -rw-r--r-- | chad-fast.cabal | 80 | ||||
| -rw-r--r-- | example/Main.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/AST.hs (renamed from src/AST.hs) | 16 | ||||
| -rw-r--r-- | src/CHAD/AST/Accum.hs (renamed from src/AST/Accum.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/Bindings.hs (renamed from src/AST/Bindings.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs (renamed from src/AST/Count.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs (renamed from src/AST/Env.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs (renamed from src/AST/Pretty.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse.hs (renamed from src/AST/Sparse.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs (renamed from src/AST/Sparse/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs (renamed from src/AST/SplitLets.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs (renamed from src/AST/Types.hs) | 4 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs (renamed from src/AST/UnMonoid.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken.hs (renamed from src/AST/Weaken.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken/Auto.hs (renamed from src/AST/Weaken/Auto.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs (renamed from src/Analysis/Identity.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/Array.hs (renamed from src/Array.hs) | 4 | ||||
| -rw-r--r-- | src/CHAD/Compile.hs (renamed from src/Compile.hs) | 18 | ||||
| -rw-r--r-- | src/CHAD/Compile/Exec.hs (renamed from src/Compile/Exec.hs) | 2 | ||||
| -rw-r--r-- | src/CHAD/Data.hs (renamed from src/Data.hs) | 4 | ||||
| -rw-r--r-- | src/CHAD/Data/VarMap.hs (renamed from src/Data/VarMap.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs (renamed from src/CHAD.hs) | 30 | ||||
| -rw-r--r-- | src/CHAD/Drev/Accum.hs (renamed from src/CHAD/Accum.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/Drev/EnvDescr.hs (renamed from src/CHAD/EnvDescr.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/Drev/Top.hs (renamed from src/CHAD/Top.hs) | 26 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs (renamed from src/CHAD/Types.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs (renamed from src/CHAD/Types/ToTan.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/Example.hs (renamed from src/Example.hs) | 31 | ||||
| -rw-r--r-- | src/CHAD/Example/GMM.hs (renamed from src/Example/GMM.hs) | 7 | ||||
| -rw-r--r-- | src/CHAD/Example/Types.hs (renamed from src/Example/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD.hs (renamed from src/ForwardAD.hs) | 18 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers.hs (renamed from src/ForwardAD/DualNumbers.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers/Types.hs (renamed from src/ForwardAD/DualNumbers/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/Interpreter.hs (renamed from src/Interpreter.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Accum.hs (renamed from src/Interpreter/Accum.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/AccumOld.hs (renamed from src/Interpreter/AccumOld.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Rep.hs (renamed from src/Interpreter/Rep.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/Language.hs (renamed from src/Language.hs) | 21 | ||||
| -rw-r--r-- | src/CHAD/Language/AST.hs (renamed from src/Language/AST.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/Lemmas.hs (renamed from src/Lemmas.hs) | 2 | ||||
| -rw-r--r-- | src/CHAD/Simplify.hs (renamed from src/Simplify.hs) | 16 | ||||
| -rw-r--r-- | src/CHAD/Simplify/TH.hs (renamed from src/Simplify/TH.hs) | 2 | ||||
| -rw-r--r-- | src/CHAD/Util/IdGen.hs (renamed from src/Util/IdGen.hs) | 2 | ||||
| -rw-r--r-- | test/Main.hs | 37 |
45 files changed, 292 insertions, 290 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 6db77b5..1cf97ae 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -18,20 +18,20 @@ import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) -import AST -import AST.Count -import AST.UnMonoid -import Array -import qualified CHAD (defaultConfig) -import CHAD.Top -import CHAD.Types -import Compile -import Data -import Example -import Example.GMM -import Example.Types -import Interpreter.Rep -import Simplify +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Array +import CHAD.Compile +import CHAD.Data +import qualified CHAD.Drev as CHAD (defaultConfig) +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example +import CHAD.Example.GMM +import CHAD.Example.Types +import CHAD.Interpreter.Rep +import CHAD.Simplify gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env)))) diff --git a/chad-fast.cabal b/chad-fast.cabal index df0409d..cafce48 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -11,47 +11,47 @@ build-type: Simple library exposed-modules: -- default ghci module on top - Example + CHAD.Example - Analysis.Identity - Array - AST - AST.Accum - AST.Bindings - AST.Count - AST.Env - AST.Pretty - AST.Sparse - AST.Sparse.Types - AST.SplitLets - AST.Types - AST.UnMonoid - AST.Weaken - AST.Weaken.Auto - CHAD - CHAD.Accum - CHAD.EnvDescr - CHAD.Top - CHAD.Types - CHAD.Types.ToTan - Compile - Compile.Exec - Data - Data.VarMap - Example.GMM - Example.Types - ForwardAD - ForwardAD.DualNumbers - ForwardAD.DualNumbers.Types - Interpreter - -- Interpreter.AccumOld - Interpreter.Rep - Language - Language.AST - Lemmas - Simplify - Simplify.TH - Util.IdGen + CHAD.Analysis.Identity + CHAD.Array + CHAD.AST + CHAD.AST.Accum + CHAD.AST.Bindings + CHAD.AST.Count + CHAD.AST.Env + CHAD.AST.Pretty + CHAD.AST.Sparse + CHAD.AST.Sparse.Types + CHAD.AST.SplitLets + CHAD.AST.Types + CHAD.AST.UnMonoid + CHAD.AST.Weaken + CHAD.AST.Weaken.Auto + CHAD.Compile + CHAD.Compile.Exec + CHAD.Data + CHAD.Data.VarMap + CHAD.Drev + CHAD.Drev.Accum + CHAD.Drev.EnvDescr + CHAD.Drev.Top + CHAD.Drev.Types + CHAD.Drev.Types.ToTan + CHAD.Example.GMM + CHAD.Example.Types + CHAD.ForwardAD + CHAD.ForwardAD.DualNumbers + CHAD.ForwardAD.DualNumbers.Types + CHAD.Interpreter + -- CHAD.Interpreter.AccumOld + CHAD.Interpreter.Rep + CHAD.Language + CHAD.Language.AST + CHAD.Lemmas + CHAD.Simplify + CHAD.Simplify.TH + CHAD.Util.IdGen other-modules: build-depends: base >= 4.19 && < 4.21, diff --git a/example/Main.hs b/example/Main.hs index 6c36857..28cb7e8 100644 --- a/example/Main.hs +++ b/example/Main.hs @@ -1,6 +1,6 @@ module Main where -import Example +import CHAD.Example main :: IO () diff --git a/src/AST.hs b/src/CHAD/AST.hs index ca6cdd1..aa6aa96 100644 --- a/src/AST.hs +++ b/src/CHAD/AST.hs @@ -18,20 +18,20 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where +module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where import Data.Functor.Const import Data.Functor.Identity import Data.Int (Int64) import Data.Kind (Type) -import Array -import AST.Accum -import AST.Sparse.Types -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST.Accum +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types -- General assumption: head of the list (whatever way it is associated) is the diff --git a/src/AST/Accum.hs b/src/CHAD/AST/Accum.hs index 988a450..ea74a95 100644 --- a/src/AST/Accum.hs +++ b/src/CHAD/AST/Accum.hs @@ -5,10 +5,10 @@ {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module AST.Accum where +module CHAD.AST.Accum where -import AST.Types -import Data +import CHAD.AST.Types +import CHAD.Data data AcPrj diff --git a/src/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs index 463586a..c1a1e77 100644 --- a/src/AST/Bindings.hs +++ b/src/CHAD/AST/Bindings.hs @@ -13,12 +13,12 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module AST.Bindings where +module CHAD.AST.Bindings where -import AST -import AST.Env -import Data -import Lemmas +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data +import CHAD.Lemmas -- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. diff --git a/src/AST/Count.hs b/src/CHAD/AST/Count.hs index ac8634e..133093a 100644 --- a/src/AST/Count.hs +++ b/src/CHAD/AST/Count.hs @@ -15,17 +15,17 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE PatternSynonyms #-} -module AST.Count where +module CHAD.AST.Count where import Data.Functor.Product import Data.Some import Data.Type.Equality import GHC.Generics (Generic, Generically(..)) -import Array -import AST -import AST.Env -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data -- | The monoid operation combines assuming that /both/ branches are taken. diff --git a/src/AST/Env.hs b/src/CHAD/AST/Env.hs index 85faba3..8e6b745 100644 --- a/src/AST/Env.hs +++ b/src/CHAD/AST/Env.hs @@ -7,14 +7,14 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -module AST.Env where +module CHAD.AST.Env where import Data.Type.Equality -import AST.Sparse -import AST.Weaken -import CHAD.Types -import Data +import CHAD.AST.Sparse +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types -- | @env'@ is a subset of @env@: each element of @env@ is either included in diff --git a/src/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index bbcfd9e..3f6a3af 100644 --- a/src/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -8,7 +8,7 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where +module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) @@ -24,11 +24,11 @@ import System.Console.ANSI (hSupportsANSI) import System.IO (stdout) import System.IO.Unsafe (unsafePerformIO) -import AST -import AST.Count -import AST.Sparse.Types -import CHAD.Types -import Data +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types class PrettyX x where diff --git a/src/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs index 2a29799..9156160 100644 --- a/src/AST/Sparse.hs +++ b/src/CHAD/AST/Sparse.hs @@ -6,13 +6,13 @@ {-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where +module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where import Data.Type.Equality -import AST -import AST.Sparse.Types -import Data (SBool(..)) +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data (SBool(..)) sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' diff --git a/src/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 10cac4e..8f41ba4 100644 --- a/src/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -3,13 +3,13 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module AST.Sparse.Types where - -import AST.Types +module CHAD.AST.Sparse.Types where import Data.Kind (Type, Constraint) import Data.Type.Equality +import CHAD.AST.Types + data Sparse t t' where SpSparse :: Sparse t t' -> Sparse t (TMaybe t') diff --git a/src/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs index 267dd87..34267e4 100644 --- a/src/AST/SplitLets.hs +++ b/src/CHAD/AST/SplitLets.hs @@ -7,13 +7,13 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where +module CHAD.AST.SplitLets (splitLets) where import Data.Type.Equality -import AST -import AST.Bindings -import Lemmas +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Lemmas splitLets :: Ex env t -> Ex env t diff --git a/src/AST/Types.hs b/src/CHAD/AST/Types.hs index 4ddcb50..059077d 100644 --- a/src/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -8,7 +8,7 @@ {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module AST.Types where +module CHAD.AST.Types where import Data.Int (Int32, Int64) import Data.GADT.Compare @@ -16,7 +16,7 @@ import Data.GADT.Show import Data.Kind (Type) import Data.Type.Equality -import Data +import CHAD.Data type data Ty diff --git a/src/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index 1712ba5..27c5f0a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -3,11 +3,11 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where +module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where -import AST -import AST.Sparse.Types -import Data +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data -- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by diff --git a/src/AST/Weaken.hs b/src/CHAD/AST/Weaken.hs index f0820b8..ac0d152 100644 --- a/src/AST/Weaken.hs +++ b/src/CHAD/AST/Weaken.hs @@ -15,15 +15,15 @@ -- The reason why this is a separate module with "little" in it: {-# LANGUAGE AllowAmbiguousTypes #-} -module AST.Weaken (module AST.Weaken, Append) where +module CHAD.AST.Weaken (module CHAD.AST.Weaken, Append) where import Data.Bifunctor (first) import Data.Functor.Const import Data.GADT.Compare import Data.Kind (Type) -import Data -import Lemmas +import CHAD.Data +import CHAD.Lemmas type Idx :: [k] -> k -> Type diff --git a/src/AST/Weaken/Auto.hs b/src/CHAD/AST/Weaken/Auto.hs index 7370df1..14d8c59 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/CHAD/AST/Weaken/Auto.hs @@ -17,7 +17,7 @@ {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS_GHC -Wno-partial-type-signatures #-} -module AST.Weaken.Auto ( +module CHAD.AST.Weaken.Auto ( autoWeak, (&.), auto, auto1, Layout(..), @@ -29,9 +29,9 @@ import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import AST.Weaken -import Data -import Lemmas +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Lemmas type family Lookup name list where diff --git a/src/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs index 7b896a3..212cc7d 100644 --- a/src/Analysis/Identity.hs +++ b/src/CHAD/Analysis/Identity.hs @@ -3,7 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -module Analysis.Identity ( +module CHAD.Analysis.Identity ( identityAnalysis, identityAnalysis', ValId(..), @@ -13,11 +13,11 @@ module Analysis.Identity ( import Data.Foldable (toList) import Data.List (intercalate) -import AST -import AST.Pretty (PrettyX(..)) -import CHAD.Types (d1, d2) -import Data -import Util.IdGen +import CHAD.AST +import CHAD.AST.Pretty (PrettyX(..)) +import CHAD.Data +import CHAD.Drev.Types (d1, d2) +import CHAD.Util.IdGen -- | Every array, scalar and accumulator has an ID. Trivial values such as diff --git a/src/Array.hs b/src/CHAD/Array.hs index 6ceb9fe..f80f961 100644 --- a/src/Array.hs +++ b/src/CHAD/Array.hs @@ -5,7 +5,7 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} -module Array where +module CHAD.Array where import Control.DeepSeq import Control.Monad.Trans.State.Strict @@ -14,7 +14,7 @@ import Data.Vector (Vector) import qualified Data.Vector as V import GHC.Generics (Generic) -import Data +import CHAD.Data data Shape n where diff --git a/src/Compile.hs b/src/CHAD/Compile.hs index 8627905..5b71651 100644 --- a/src/Compile.hs +++ b/src/CHAD/Compile.hs @@ -8,7 +8,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile, compileStderr) where +module CHAD.Compile (compile, compileStderr) where import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) @@ -43,14 +43,14 @@ import System.IO.Unsafe (unsafePerformIO) import Prelude hiding ((^)) import qualified Prelude -import Array -import AST -import AST.Pretty (ppSTy, ppExpr) -import AST.Sparse.Types (isDense) -import Compile.Exec -import Data -import Interpreter.Rep -import qualified Util.IdGen as IdGen +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty (ppSTy, ppExpr) +import CHAD.AST.Sparse.Types (isDense) +import CHAD.Compile.Exec +import CHAD.Data +import CHAD.Interpreter.Rep +import qualified CHAD.Util.IdGen as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). diff --git a/src/Compile/Exec.hs b/src/CHAD/Compile/Exec.hs index ad4180f..5b4afc8 100644 --- a/src/Compile/Exec.hs +++ b/src/CHAD/Compile/Exec.hs @@ -1,6 +1,6 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TupleSections #-} -module Compile.Exec ( +module CHAD.Compile.Exec ( KernelLib, buildKernel, callKernelFun, diff --git a/src/Data.hs b/src/CHAD/Data.hs index e6978c8..8c7605c 100644 --- a/src/Data.hs +++ b/src/CHAD/Data.hs @@ -8,7 +8,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl), If) where +module CHAD.Data (module CHAD.Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare @@ -18,7 +18,7 @@ import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) -import Lemmas (Append) +import CHAD.Lemmas (Append) data Dict c where diff --git a/src/Data/VarMap.hs b/src/CHAD/Data/VarMap.hs index 2712b08..6e16b82 100644 --- a/src/Data/VarMap.hs +++ b/src/CHAD/Data/VarMap.hs @@ -4,7 +4,7 @@ {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -module Data.VarMap ( +module CHAD.Data.VarMap ( VarMap, empty, insert, @@ -27,9 +27,9 @@ import Data.Some import qualified Data.Vector.Storable as VS import Unsafe.Coerce -import AST.Env -import AST.Types -import AST.Weaken +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken type role VarMap _ nominal -- ensure that 'env' is not phantom diff --git a/src/CHAD.hs b/src/CHAD/Drev.hs index 298d964..595d3c7 100644 --- a/src/CHAD.hs +++ b/src/CHAD/Drev.hs @@ -23,7 +23,7 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module CHAD ( +module CHAD.Drev ( drev, freezeRet, CHADConfig(..), @@ -37,20 +37,20 @@ import Data.Functor.Const import Data.Some import Data.Type.Equality (type (==), testEquality) -import Analysis.Identity (ValId(..), validSplitEither) -import AST -import AST.Bindings -import AST.Count -import AST.Env -import AST.Sparse -import AST.Weaken.Auto -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap -import Data.VarMap (VarMap) -import Lemmas +import CHAD.Analysis.Identity (ValId(..), validSplitEither) +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Data.VarMap (VarMap) +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types +import CHAD.Lemmas ------------------------------ TAPES AND BINDINGS ------------------------------ diff --git a/src/CHAD/Accum.hs b/src/CHAD/Drev/Accum.hs index a7bc53f..6f25f11 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Drev/Accum.hs @@ -1,13 +1,13 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeOperators #-} -- | TODO this module is a grab-bag of random utility functions that are shared --- between CHAD and CHAD.Top. -module CHAD.Accum where +-- between CHAD.Drev and CHAD.Drev.Top. +module CHAD.Drev.Accum where -import AST -import CHAD.Types -import Data -import AST.Env +import CHAD.AST +import CHAD.Data +import CHAD.Drev.Types +import CHAD.AST.Env d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs index 49ae0e6..5a90303 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/Drev/EnvDescr.hs @@ -7,18 +7,18 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.EnvDescr where +module CHAD.Drev.EnvDescr where import Data.Kind (Type) import Data.Some import GHC.TypeLits (Symbol) -import Analysis.Identity (ValId(..)) -import AST.Env -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Analysis.Identity (ValId(..)) +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types type Storage :: Symbol -> Type diff --git a/src/CHAD/Top.hs b/src/CHAD/Drev/Top.hs index 4814bdf..510e73e 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Drev/Top.hs @@ -8,20 +8,20 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Top where +module CHAD.Drev.Top where -import Analysis.Identity -import AST -import AST.Env -import AST.Sparse -import AST.SplitLets -import AST.Weaken.Auto -import CHAD -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types type family MergeEnv env where diff --git a/src/CHAD/Types.hs b/src/CHAD/Drev/Types.hs index 44ac20e..367a974 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -2,11 +2,11 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Types where +module CHAD.Drev.Types where -import AST.Accum -import AST.Types -import Data +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data type family D1 t where diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs index 888fed4..019119c 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -1,14 +1,14 @@ {-# LANGUAGE GADTs #-} -module CHAD.Types.ToTan where +module CHAD.Drev.Types.ToTan where import Data.Bifunctor (bimap) -import Array -import AST.Types -import CHAD.Types -import Data -import ForwardAD -import Interpreter.Rep +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) diff --git a/src/Example.hs b/src/CHAD/Example.hs index e996002..884f99a 100644 --- a/src/Example.hs +++ b/src/CHAD/Example.hs @@ -7,23 +7,24 @@ {-# LANGUAGE TypeApplications #-} {-# OPTIONS -Wno-unused-imports #-} -module Example where - -import Array -import AST -import AST.Count -import AST.Pretty -import AST.UnMonoid -import CHAD -import CHAD.Top -import CHAD.Types -import ForwardAD -import Interpreter -import Language -import Simplify +module CHAD.Example where import Debug.Trace -import Example.Types + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Data +import CHAD.Drev +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.Interpreter +import CHAD.Language +import CHAD.Simplify -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) diff --git a/src/Example/GMM.hs b/src/CHAD/Example/GMM.hs index 206e534..8f834e0 100644 --- a/src/Example/GMM.hs +++ b/src/CHAD/Example/GMM.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeApplications #-} -module Example.GMM where +module CHAD.Example.GMM where -import Example.Types -import Language +import CHAD.Data (SList(..), SNat(..)) +import CHAD.Example.Types +import CHAD.Language diff --git a/src/Example/Types.hs b/src/CHAD/Example/Types.hs index d63159b..1e2f72d 100644 --- a/src/Example/Types.hs +++ b/src/CHAD/Example/Types.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} -module Example.Types where +module CHAD.Example.Types where -import AST -import Data +import CHAD.AST +import CHAD.Data type R = TScal TF64 diff --git a/src/ForwardAD.hs b/src/CHAD/ForwardAD.hs index 6655423..7126e10 100644 --- a/src/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -4,21 +4,21 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD where +module CHAD.ForwardAD where import Data.Bifunctor (bimap) import System.IO.Unsafe -- import Debug.Trace --- import AST.Pretty +-- import CHAD.AST.Pretty -import Array -import AST -import Compile -import Data -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.Compile +import CHAD.Data +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep -- | Tangent along a type (coincides with cotangent for these types) diff --git a/src/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs index a1e9d0d..a71efc8 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -14,14 +14,14 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD.DualNumbers ( +module CHAD.ForwardAD.DualNumbers ( dfwdDN, DN, DNS, DNE, dn, dne, ) where -import AST -import Data -import ForwardAD.DualNumbers.Types +import CHAD.AST +import CHAD.Data +import CHAD.ForwardAD.DualNumbers.Types dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs index dcacf5f..5d5dd9e 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -1,10 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD.DualNumbers.Types where +module CHAD.ForwardAD.DualNumbers.Types where -import AST.Types -import Data +import CHAD.AST.Types +import CHAD.Data -- | Dual-numbers transformation diff --git a/src/Interpreter.hs b/src/CHAD/Interpreter.hs index e1c81cd..a9421e6 100644 --- a/src/Interpreter.hs +++ b/src/CHAD/Interpreter.hs @@ -14,7 +14,7 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Interpreter ( +module CHAD.Interpreter ( interpret, interpretOpen, Value(..), @@ -36,12 +36,12 @@ import System.IO.Unsafe (unsafePerformIO) import Debug.Trace -import Array -import AST -import AST.Pretty -import AST.Sparse.Types -import Data -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Interpreter.Rep newtype AcM s a = AcM { unAcM :: IO a } diff --git a/src/Interpreter/Accum.hs b/src/CHAD/Interpreter/Accum.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/Accum.hs +++ b/src/CHAD/Interpreter/Accum.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/AccumOld.hs b/src/CHAD/Interpreter/AccumOld.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/AccumOld.hs +++ b/src/CHAD/Interpreter/AccumOld.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs index 1682303..fadc6be 100644 --- a/src/Interpreter/Rep.hs +++ b/src/CHAD/Interpreter/Rep.hs @@ -3,7 +3,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module Interpreter.Rep where +module CHAD.Interpreter.Rep where import Control.DeepSeq import Data.Coerce (coerce) @@ -12,10 +12,10 @@ import Data.Foldable (toList) import Data.IORef import GHC.Exts (withDict) -import Array -import AST -import AST.Pretty -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.Data type family Rep t where diff --git a/src/Language.hs b/src/CHAD/Language.hs index 4886ddc..6dc91a5 100644 --- a/src/Language.hs +++ b/src/CHAD/Language.hs @@ -6,25 +6,24 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} -module Language ( +module CHAD.Language ( fromNamed, NExpr, Ex, - module Language, - module AST.Types, - module Data, + module CHAD.Language, + module CHAD.AST.Types, Lookup, ) where import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) -import Array -import AST -import AST.Sparse.Types -import AST.Types -import CHAD.Types -import Data -import Language.AST +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.Language.AST data a :-> b = a :-> b diff --git a/src/Language/AST.hs b/src/CHAD/Language/AST.hs index 3d6ede5..b270844 100644 --- a/src/Language/AST.hs +++ b/src/CHAD/Language/AST.hs @@ -14,18 +14,18 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module Language.AST where +module CHAD.Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) -import Array -import AST -import AST.Sparse.Types -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types type NExpr :: [(Symbol, Ty)] -> Ty -> Type diff --git a/src/Lemmas.hs b/src/CHAD/Lemmas.hs index 31a43ed..55ef042 100644 --- a/src/Lemmas.hs +++ b/src/CHAD/Lemmas.hs @@ -4,7 +4,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE AllowAmbiguousTypes #-} -module Lemmas (module Lemmas, (:~:)(Refl)) where +module CHAD.Lemmas (module CHAD.Lemmas, (:~:)(Refl)) where import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) diff --git a/src/Simplify.hs b/src/CHAD/Simplify.hs index 19d0c17..2510cc5 100644 --- a/src/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -12,7 +12,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Simplify ( +module CHAD.Simplify ( simplifyN, simplifyFix, SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, ) where @@ -24,13 +24,13 @@ import Data.Monoid (Any(..)) import Debug.Trace -import AST -import AST.Count -import AST.Pretty -import AST.Sparse.Types -import AST.UnMonoid (acPrjCompose) -import Data -import Simplify.TH +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.AST.UnMonoid (acPrjCompose) +import CHAD.Data +import CHAD.Simplify.TH data SimplifyConfig = SimplifyConfig diff --git a/src/Simplify/TH.hs b/src/CHAD/Simplify/TH.hs index 03a74de..4af5394 100644 --- a/src/Simplify/TH.hs +++ b/src/CHAD/Simplify/TH.hs @@ -1,5 +1,5 @@ {-# LANGUAGE TemplateHaskellQuotes #-} -module Simplify.TH (simprec) where +module CHAD.Simplify.TH (simprec) where import Data.Bifunctor (first) import Data.Char diff --git a/src/Util/IdGen.hs b/src/CHAD/Util/IdGen.hs index 3f6611d..d4fd945 100644 --- a/src/Util/IdGen.hs +++ b/src/CHAD/Util/IdGen.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -module Util.IdGen where +module CHAD.Util.IdGen where import Control.Monad.Fix import Control.Monad.Trans.State.Strict diff --git a/test/Main.hs b/test/Main.hs index 358073f..652de02 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,24 +24,25 @@ import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range import Test.Framework -import Array -import AST hiding ((.>)) -import AST.Count (pruneExpr) -import AST.Pretty -import AST.UnMonoid -import CHAD.Top -import CHAD.Types -import CHAD.Types.ToTan -import Compile -import qualified Example -import qualified Example.GMM as Example -import Example.Types -import ForwardAD -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep -import Language -import Simplify +import CHAD.Array +import CHAD.AST hiding ((.>)) +import CHAD.AST.Count (pruneExpr) +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Drev.Types.ToTan +import qualified CHAD.Example as Example +import qualified CHAD.Example.GMM as Example +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep +import CHAD.Language +import CHAD.Simplify data TypedValue t = TypedValue (STy t) (Rep t) |
