summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-26 23:05:30 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-26 23:05:30 +0100
commitade38c607a8d0dc8dc1d701084ed88df2fa89df9 (patch)
tree2183d63164a27fe84bb00b6c1920fe6c2be1e0e8
parentae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (diff)
Working argument accum mode (...)
The derivative of 'neural' in full accum mode is pretty atrocious now; I think this is because when you have code like this: \(a :: Arr 1 R) -> let b = a in let c = b in sum d then because the argument, as well as both let bindings, bind a value of array type, each will introduce an accumulator, hence resulting in three (!) nested `with` clauses that each just contribute their result back to their parent. This is pointless, and we should fix this.
-rw-r--r--src/AST/Pretty.hs2
-rw-r--r--src/Interpreter.hs25
-rw-r--r--src/Interpreter/Rep.hs2
-rw-r--r--test/Main.hs8
4 files changed, 25 insertions, 12 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index ec8574f..663e9b0 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -218,7 +218,7 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
- EZero _ -> return $ showString "zero"
+ EZero t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")
EPlus _ a b -> do
a' <- ppExpr' 11 val a
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 56ebf82..bb4952c 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -30,7 +30,6 @@ import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
-import GHC.Stack
import Array
import AST
@@ -250,7 +249,6 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- putStrLn $ "withAccum: " ++ show t
accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
@@ -324,7 +322,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
@@ -372,13 +370,19 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
go mk dep' dim' idx' val' $ \arr pish ->
k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
newAcDense typ SZ () val = case typ of
+ STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> newAcSparse t1 SZ () x
Right y -> Right <$> newAcSparse t2 SZ () y
_ -> error "newAcDense: invalid dense type"
newAcDense typ (SS dep) idx val = case typ of
+ STPair t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
+ (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
+ _ -> error "Index/value mismatch in newAc pair"
STEither t1 t2 ->
case (idx, val) of
(Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
@@ -400,6 +404,7 @@ readAcSparse typ val = case typ of
readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
readAcDense typ val = case typ of
+ STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> readAcSparse t1 x
Right y -> Right <$> readAcSparse t2 y
@@ -505,19 +510,27 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
accumAddDense typ SZ ref () val = case typ of
+ STPair t1 t2 ->
+ (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
+ accumAddSparse t2 SZ (snd ref) () (snd val))
STEither t1 t2 ->
case (ref, val) of
(Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val')
(Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val')
- _ -> error "Mismatched Either in accumAdd"
+ _ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
accumAddDense typ (SS dep) ref idx val = case typ of
+ STPair t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val')
+ (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val')
+ _ -> error "Mismatched Either in accumAddDense pair"
STEither t1 t2 ->
case (ref, idx, val) of
(Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val')
(Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
- _ -> error "Mismatched Either in accumAdd"
+ _ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 0007991..335ad1f 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -37,7 +37,7 @@ type family RepAcSparse t where
-- Immutable, and does not necessarily have a zero.
type family RepAcDense t where
RepAcDense TNil = ()
- -- RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
+ RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b)
-- RepAcDense (TMaybe t) = RepAcSparse (TMaybe t) -- ^ This can be optimised to TMaybe (RepAcSparse t), but that makes accumAddDense very hard to write. And in any case, we don't need it because D2 will not produce Maybe of Maybe.
-- RepAcDense (TArr n t) = Array n (RepAcSparse t)
diff --git a/test/Main.hs b/test/Main.hs
index d18884e..b6f9f2b 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -62,11 +62,11 @@ gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input
toTan typ primal der = case typ of
STNil -> der
STPair t1 t2 -> case der of
- Left () -> bimap (zeroTan t1) (zeroTan t2) primal
- Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
STEither t1 t2 -> case der of
- Left () -> bimap (zeroTan t1) (zeroTan t2) primal
- Right d -> case (primal, d) of
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just d -> case (primal, d) of
(Left p, Left d') -> Left (toTan t1 p d')
(Right p, Right d') -> Right (toTan t2 p d')
_ -> error "Primal and cotangent disagree on Either alternative"