diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Example.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Example.hs')
| -rw-r--r-- | src/CHAD/Example.hs | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/src/CHAD/Example.hs b/src/CHAD/Example.hs new file mode 100644 index 0000000..884f99a --- /dev/null +++ b/src/CHAD/Example.hs @@ -0,0 +1,197 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} + +{-# OPTIONS -Wno-unused-imports #-} +module CHAD.Example where + +import Debug.Trace + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Data +import CHAD.Drev +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.Interpreter +import CHAD.Language +import CHAD.Simplify + + +-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) + + +pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +pipeline config term + | Dict <- styKnown (d2 (typeOf term)) = + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ chad' config knownEnv $ + simplifyFix $ term + +-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures +pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () +pipeline' config term + | Dict <- styKnown (d2 (typeOf term)) = + pprintExpr (pipeline config term) + + +bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c +bin op a b = EOp ext op (EPair ext a b) + +senv1 :: SList STy [TScal TF32, TScal TF32] +senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil + +-- x y |- x * y + x +-- +-- let x3 = (x1, x2) +-- x4 = ((*) x3, x1) +-- in ( (+) x4 +-- , let x5 = 1.0 +-- x6 = Inr (x5, x5) +-- in case x6 of +-- Inl x7 -> return () +-- Inr x8 -> +-- let x9 = fst x8 +-- x10 = Inr (snd x3 * x9, fst x3 * x9) +-- in case x10 of +-- Inl x11 -> return () +-- Inr x12 -> +-- let x13 = fst x12 +-- in one "v1" x13 >>= \x14 -> +-- let x15 = snd x12 +-- in one "v2" x15 >>= \x16 -> +-- let x17 = snd x8 +-- in one "v1" x17) +-- +-- ( (x1 * x2) + x1 +-- , let x5 = 1.0 +-- in do one "v1" (x2 * x5) +-- one "v2" (x1 * x5) +-- one "v1" x5) +ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex1 = fromNamed $ lambda #x $ lambda #y $ body $ + #x * #y + #x + +-- x y |- let z = x + y in z * (z + x) +ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex2 = fromNamed $ lambda #x $ lambda #y $ body $ + let_ #z (#x + #y) $ + #z * (#z + #x) + +-- x y |- if x < y then 2 * x else 3 + x +ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex3 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x) (3 * #x) + +-- x y |- if x < y then 2 * x + y * y else 3 + x +ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex4 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) + +-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} +ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) +ex5 = fromNamed $ lambda #x $ lambda #y $ body $ + case_ #x (#a :-> #a * #y) + (#b :-> #b * (#y + 1)) + +-- x:R n:I |- let a = unit x +-- b = build1 n (\i. let c = idx0 a in c * c) +-- in idx0 (b ! 3) +ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) +ex6 = fromNamed $ lambda #x $ lambda #n $ body $ + let_ #a (unit #x) $ + let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ + #b ! pair nil 3 + +-- A "neural network" except it's just scalars, not matrices. +-- ps:((((), (R,R)), (R,R)), (R,R)) x:R +-- |- let p1 = snd ps +-- p1' = fst ps +-- x1 = fst p1 * x + snd p1 +-- p2 = snd p1' +-- p2' = fst p1' +-- x2 = fst p2 * x + snd p2 +-- p3 = snd p2' +-- p3' = fst p2' +-- x3 = fst p3 * x + snd p3 +-- in x3 +ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R +ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ + let tR = STScal STF64 + tpair = STPair tR tR + + layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R) + => STy p -> NExpr env R + layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t = + let_ #par (snd_ #parstup) $ + let_ #restpars (fst_ #parstup) $ + let_ #inp (fst_ #par * #inp + snd_ #par) $ + let_ #parstup #restpars $ + layer t + layer STNil = #inp + layer _ = error "Invalid layer inputs" + + in let_ #parstup #pars123 $ + let_ #inp #input $ + layer (STPair (STPair (STPair STNil tpair) tpair) tpair) + +neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R +neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ + let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ + -- prod = wei `matmul` x + let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> + #wei ! #idx * #x ! pair nil (snd_ #idx)) $ + -- relu (prod + bias) + build (SS SZ) (shape #prod) $ #idx :-> + let_ #out (#prod ! #idx + #bias ! #idx) $ + if_ (#out .<= const_ 0) (const_ 0) #out + + in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $ + let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $ + let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ + #x3 ! nil + +type NeuralGrad = ((Array N2 Double, Array N1 Double) + ,(Array N2 Double, Array N1 Double) + ,Array N1 Double + ,Array N1 Double) + +neuralGo :: (Double -- primal + ,NeuralGrad -- gradient using CHAD + ,NeuralGrad) -- gradient using dual-numbers forward AD +neuralGo = + let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) + lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) + lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] + input = arrayFromList (ShNil `ShCons` 2) [1,1] + argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) + revderiv = + simplifyN 20 $ + ELet ext (EConst ext STF64 1.0) $ + chad defaultConfig knownEnv neural + (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') + (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 + in trace (ppExpr knownEnv revderiv) $ + (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) + +-- The build body uses free variables in a non-linear way, so their primal +-- values are required in the dual of the build. Thus, compositionally, they +-- are stored in the tape from each individual lambda invocation. This results +-- in n copies of y and z, where only one copy would have sufficed. +exUniformFree :: Ex '[R, I64] R +exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ + let_ #y (#x * 2) $ + let_ #z (#x * 3) $ + idx0 $ sum1i $ + build1 #n $ #i :-> #y * #z + toFloat_ #i |
