diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Example.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Example.hs')
| -rw-r--r-- | src/Example.hs | 196 |
1 files changed, 0 insertions, 196 deletions
diff --git a/src/Example.hs b/src/Example.hs deleted file mode 100644 index e996002..0000000 --- a/src/Example.hs +++ /dev/null @@ -1,196 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} - -{-# OPTIONS -Wno-unused-imports #-} -module Example where - -import Array -import AST -import AST.Count -import AST.Pretty -import AST.UnMonoid -import CHAD -import CHAD.Top -import CHAD.Types -import ForwardAD -import Interpreter -import Language -import Simplify - -import Debug.Trace -import Example.Types - - --- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) - - -pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) -pipeline config term - | Dict <- styKnown (d2 (typeOf term)) = - simplifyFix $ pruneExpr knownEnv $ - simplifyFix $ unMonoid $ - simplifyFix $ chad' config knownEnv $ - simplifyFix $ term - --- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures -pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () -pipeline' config term - | Dict <- styKnown (d2 (typeOf term)) = - pprintExpr (pipeline config term) - - -bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c -bin op a b = EOp ext op (EPair ext a b) - -senv1 :: SList STy [TScal TF32, TScal TF32] -senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil - --- x y |- x * y + x --- --- let x3 = (x1, x2) --- x4 = ((*) x3, x1) --- in ( (+) x4 --- , let x5 = 1.0 --- x6 = Inr (x5, x5) --- in case x6 of --- Inl x7 -> return () --- Inr x8 -> --- let x9 = fst x8 --- x10 = Inr (snd x3 * x9, fst x3 * x9) --- in case x10 of --- Inl x11 -> return () --- Inr x12 -> --- let x13 = fst x12 --- in one "v1" x13 >>= \x14 -> --- let x15 = snd x12 --- in one "v2" x15 >>= \x16 -> --- let x17 = snd x8 --- in one "v1" x17) --- --- ( (x1 * x2) + x1 --- , let x5 = 1.0 --- in do one "v1" (x2 * x5) --- one "v2" (x1 * x5) --- one "v1" x5) -ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex1 = fromNamed $ lambda #x $ lambda #y $ body $ - #x * #y + #x - --- x y |- let z = x + y in z * (z + x) -ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex2 = fromNamed $ lambda #x $ lambda #y $ body $ - let_ #z (#x + #y) $ - #z * (#z + #x) - --- x y |- if x < y then 2 * x else 3 + x -ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex3 = fromNamed $ lambda #x $ lambda #y $ body $ - if_ (#x .< #y) (2 * #x) (3 * #x) - --- x y |- if x < y then 2 * x + y * y else 3 + x -ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex4 = fromNamed $ lambda #x $ lambda #y $ body $ - if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) - --- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} -ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) -ex5 = fromNamed $ lambda #x $ lambda #y $ body $ - case_ #x (#a :-> #a * #y) - (#b :-> #b * (#y + 1)) - --- x:R n:I |- let a = unit x --- b = build1 n (\i. let c = idx0 a in c * c) --- in idx0 (b ! 3) -ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) -ex6 = fromNamed $ lambda #x $ lambda #n $ body $ - let_ #a (unit #x) $ - let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ - #b ! pair nil 3 - --- A "neural network" except it's just scalars, not matrices. --- ps:((((), (R,R)), (R,R)), (R,R)) x:R --- |- let p1 = snd ps --- p1' = fst ps --- x1 = fst p1 * x + snd p1 --- p2 = snd p1' --- p2' = fst p1' --- x2 = fst p2 * x + snd p2 --- p3 = snd p2' --- p3' = fst p2' --- x3 = fst p3 * x + snd p3 --- in x3 -ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R -ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ - let tR = STScal STF64 - tpair = STPair tR tR - - layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R) - => STy p -> NExpr env R - layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t = - let_ #par (snd_ #parstup) $ - let_ #restpars (fst_ #parstup) $ - let_ #inp (fst_ #par * #inp + snd_ #par) $ - let_ #parstup #restpars $ - layer t - layer STNil = #inp - layer _ = error "Invalid layer inputs" - - in let_ #parstup #pars123 $ - let_ #inp #input $ - layer (STPair (STPair (STPair STNil tpair) tpair) tpair) - -neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R -neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ - let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ - -- prod = wei `matmul` x - let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> - #wei ! #idx * #x ! pair nil (snd_ #idx)) $ - -- relu (prod + bias) - build (SS SZ) (shape #prod) $ #idx :-> - let_ #out (#prod ! #idx + #bias ! #idx) $ - if_ (#out .<= const_ 0) (const_ 0) #out - - in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $ - let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $ - let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ - #x3 ! nil - -type NeuralGrad = ((Array N2 Double, Array N1 Double) - ,(Array N2 Double, Array N1 Double) - ,Array N1 Double - ,Array N1 Double) - -neuralGo :: (Double -- primal - ,NeuralGrad -- gradient using CHAD - ,NeuralGrad) -- gradient using dual-numbers forward AD -neuralGo = - let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) - lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) - lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] - input = arrayFromList (ShNil `ShCons` 2) [1,1] - argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) - revderiv = - simplifyN 20 $ - ELet ext (EConst ext STF64 1.0) $ - chad defaultConfig knownEnv neural - (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 - in trace (ppExpr knownEnv revderiv) $ - (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) - --- The build body uses free variables in a non-linear way, so their primal --- values are required in the dual of the build. Thus, compositionally, they --- are stored in the tape from each individual lambda invocation. This results --- in n copies of y and z, where only one copy would have sufficed. -exUniformFree :: Ex '[R, I64] R -exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ - let_ #y (#x * 2) $ - let_ #z (#x * 3) $ - idx0 $ sum1i $ - build1 #n $ #i :-> #y * #z + toFloat_ #i |
