diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Interpreter.hs | |
parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 58d79a5..f8e7e98 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -99,6 +99,15 @@ interpret'Rec env = \case EMaybe _ a b e -> let STMaybe t1 = typeOf e in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e + ELNil _ _ _ -> return Nothing + ELInl _ _ e -> Just . Left <$> interpret' env e + ELInr _ _ e -> Just . Right <$> interpret' env e + ELCase _ e a b c -> + let STLEither t1 t2 = typeOf e + in interpret' env e >>= \case + Nothing -> interpret' env a + Just (Left x) -> interpret' (V t1 x `SCons` env) b + Just (Right y) -> interpret' (V t2 y `SCons` env) c EConstArr _ _ _ v -> return v EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a @@ -136,9 +145,9 @@ interpret'Rec env = \case EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ a b - | STArr n _ <- typeOf a - -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EIdx _ a b -> + let STArr n _ = typeOf a + in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e EOp _ op e -> interpretOp op <$> interpret' env e ECustom _ t1 t2 _ pr _ _ e1 e2 -> do @@ -154,8 +163,8 @@ interpret'Rec env = \case val <- interpret' env e2 accum <- interpret' env e3 accumAddSparse t p accum idx val - EZero _ t -> do - return $ zeroD2 t + EZero _ t ezi -> do + return $ zeroD2 t ezi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b @@ -250,7 +259,7 @@ onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = +onehotD2 (SAPArrIdx prj) (STArr n a) idx val = Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) @@ -299,7 +308,7 @@ newAcSparse typ prj idx val = case (typ, prj) of (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val + (STArr n t, SAPArrIdx prj') -> newIORef . Just =<< newAcArray n t prj' idx val (STAccum{}, _) -> error "Accumulators not allowed in source program" @@ -380,7 +389,7 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of (arrayMapM (newAcSparse t1 SAPHere ()) val') (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj' _) -> + (STArr n t1, SAPArrIdx prj') -> let ((arrindex', arrsh'), idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = unTupRepIdx ShNil ShCons n arrsh' |