aboutsummaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs196
1 files changed, 0 insertions, 196 deletions
diff --git a/src/Example.hs b/src/Example.hs
deleted file mode 100644
index e996002..0000000
--- a/src/Example.hs
+++ /dev/null
@@ -1,196 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE OverloadedLabels #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE TypeApplications #-}
-
-{-# OPTIONS -Wno-unused-imports #-}
-module Example where
-
-import Array
-import AST
-import AST.Count
-import AST.Pretty
-import AST.UnMonoid
-import CHAD
-import CHAD.Top
-import CHAD.Types
-import ForwardAD
-import Interpreter
-import Language
-import Simplify
-
-import Debug.Trace
-import Example.Types
-
-
--- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
-
-
-pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
-pipeline config term
- | Dict <- styKnown (d2 (typeOf term)) =
- simplifyFix $ pruneExpr knownEnv $
- simplifyFix $ unMonoid $
- simplifyFix $ chad' config knownEnv $
- simplifyFix $ term
-
--- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures
-pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO ()
-pipeline' config term
- | Dict <- styKnown (d2 (typeOf term)) =
- pprintExpr (pipeline config term)
-
-
-bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
-bin op a b = EOp ext op (EPair ext a b)
-
-senv1 :: SList STy [TScal TF32, TScal TF32]
-senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
-
--- x y |- x * y + x
---
--- let x3 = (x1, x2)
--- x4 = ((*) x3, x1)
--- in ( (+) x4
--- , let x5 = 1.0
--- x6 = Inr (x5, x5)
--- in case x6 of
--- Inl x7 -> return ()
--- Inr x8 ->
--- let x9 = fst x8
--- x10 = Inr (snd x3 * x9, fst x3 * x9)
--- in case x10 of
--- Inl x11 -> return ()
--- Inr x12 ->
--- let x13 = fst x12
--- in one "v1" x13 >>= \x14 ->
--- let x15 = snd x12
--- in one "v2" x15 >>= \x16 ->
--- let x17 = snd x8
--- in one "v1" x17)
---
--- ( (x1 * x2) + x1
--- , let x5 = 1.0
--- in do one "v1" (x2 * x5)
--- one "v2" (x1 * x5)
--- one "v1" x5)
-ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex1 = fromNamed $ lambda #x $ lambda #y $ body $
- #x * #y + #x
-
--- x y |- let z = x + y in z * (z + x)
-ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex2 = fromNamed $ lambda #x $ lambda #y $ body $
- let_ #z (#x + #y) $
- #z * (#z + #x)
-
--- x y |- if x < y then 2 * x else 3 + x
-ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex3 = fromNamed $ lambda #x $ lambda #y $ body $
- if_ (#x .< #y) (2 * #x) (3 * #x)
-
--- x y |- if x < y then 2 * x + y * y else 3 + x
-ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex4 = fromNamed $ lambda #x $ lambda #y $ body $
- if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)
-
--- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
-ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
-ex5 = fromNamed $ lambda #x $ lambda #y $ body $
- case_ #x (#a :-> #a * #y)
- (#b :-> #b * (#y + 1))
-
--- x:R n:I |- let a = unit x
--- b = build1 n (\i. let c = idx0 a in c * c)
--- in idx0 (b ! 3)
-ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
-ex6 = fromNamed $ lambda #x $ lambda #n $ body $
- let_ #a (unit #x) $
- let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
- #b ! pair nil 3
-
--- A "neural network" except it's just scalars, not matrices.
--- ps:((((), (R,R)), (R,R)), (R,R)) x:R
--- |- let p1 = snd ps
--- p1' = fst ps
--- x1 = fst p1 * x + snd p1
--- p2 = snd p1'
--- p2' = fst p1'
--- x2 = fst p2 * x + snd p2
--- p3 = snd p2'
--- p3' = fst p2'
--- x3 = fst p3 * x + snd p3
--- in x3
-ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
-ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
- let tR = STScal STF64
- tpair = STPair tR tR
-
- layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
- => STy p -> NExpr env R
- layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
- let_ #par (snd_ #parstup) $
- let_ #restpars (fst_ #parstup) $
- let_ #inp (fst_ #par * #inp + snd_ #par) $
- let_ #parstup #restpars $
- layer t
- layer STNil = #inp
- layer _ = error "Invalid layer inputs"
-
- in let_ #parstup #pars123 $
- let_ #inp #input $
- layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
-
-neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
-neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
- let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
- -- prod = wei `matmul` x
- let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
- #wei ! #idx * #x ! pair nil (snd_ #idx)) $
- -- relu (prod + bias)
- build (SS SZ) (shape #prod) $ #idx :->
- let_ #out (#prod ! #idx + #bias ! #idx) $
- if_ (#out .<= const_ 0) (const_ 0) #out
-
- in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
- let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
- let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
- #x3 ! nil
-
-type NeuralGrad = ((Array N2 Double, Array N1 Double)
- ,(Array N2 Double, Array N1 Double)
- ,Array N1 Double
- ,Array N1 Double)
-
-neuralGo :: (Double -- primal
- ,NeuralGrad -- gradient using CHAD
- ,NeuralGrad) -- gradient using dual-numbers forward AD
-neuralGo =
- let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
- lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
- lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
- input = arrayFromList (ShNil `ShCons` 2) [1,1]
- argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
- revderiv =
- simplifyN 20 $
- ELet ext (EConst ext STF64 1.0) $
- chad defaultConfig knownEnv neural
- (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
- (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
- (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
- in trace (ppExpr knownEnv revderiv) $
- (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
-
--- The build body uses free variables in a non-linear way, so their primal
--- values are required in the dual of the build. Thus, compositionally, they
--- are stored in the tape from each individual lambda invocation. This results
--- in n copies of y and z, where only one copy would have sufficed.
-exUniformFree :: Ex '[R, I64] R
-exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $
- let_ #y (#x * 2) $
- let_ #z (#x * 3) $
- idx0 $ sum1i $
- build1 #n $ #i :-> #y * #z + toFloat_ #i