diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 | |
parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
-rw-r--r-- | src/AST.hs | 22 | ||||
-rw-r--r-- | src/AST/Types.hs | 51 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 16 | ||||
-rw-r--r-- | src/CHAD.hs | 75 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 4 | ||||
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 16 | ||||
-rw-r--r-- | src/Compile.hs | 221 | ||||
-rw-r--r-- | src/Data.hs | 5 | ||||
-rw-r--r-- | src/Example.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 61 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 3 | ||||
-rw-r--r-- | src/Util/IdGen.hs | 3 |
12 files changed, 285 insertions, 194 deletions
@@ -246,28 +246,6 @@ extOf = \case EOneHot x _ _ _ _ -> x EError x _ _ -> x -unSTy :: STy t -> Ty -unSTy = \case - STNil -> TNil - STPair a b -> TPair (unSTy a) (unSTy b) - STEither a b -> TEither (unSTy a) (unSTy b) - STMaybe t -> TMaybe (unSTy t) - STArr n t -> TArr (unSNat n) (unSTy t) - STScal t -> TScal (unSScalTy t) - STAccum t -> TAccum (unSTy t) - -unSEnv :: SList STy env -> [Ty] -unSEnv SNil = [] -unSEnv (SCons t l) = unSTy t : unSEnv l - -unSScalTy :: SScalTy t -> ScalTy -unSScalTy = \case - STI32 -> TI32 - STI64 -> TI64 - STF32 -> TF32 - STF64 -> TF64 - STBool -> TBool - subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t subst1 repl = subst $ \x t -> \case IZ -> repl IS i -> EVar x t i diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 0b41671..217b2f5 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -8,7 +9,9 @@ module AST.Types where import Data.Int (Int32, Int64) +import Data.GADT.Show import Data.Kind (Type) +import Data.Some import Data.Type.Equality import Data @@ -54,6 +57,8 @@ instance TestEquality STy where testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl testEquality STAccum{} _ = Nothing +instance GShow STy where gshowsPrec = defaultGshowsPrec + data SScalTy t where STI32 :: SScalTy TI32 STI64 :: SScalTy TI64 @@ -70,6 +75,8 @@ instance TestEquality SScalTy where testEquality STBool STBool = Just Refl testEquality _ _ = Nothing +instance GShow SScalTy where gshowsPrec = defaultGshowsPrec + scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) scalRepIsShow STI32 = Dict scalRepIsShow STI64 = Dict @@ -82,6 +89,50 @@ type TIx = TScal TI64 tIx :: STy TIx tIx = STScal STI64 +unSTy :: STy t -> Ty +unSTy = \case + STNil -> TNil + STPair a b -> TPair (unSTy a) (unSTy b) + STEither a b -> TEither (unSTy a) (unSTy b) + STMaybe t -> TMaybe (unSTy t) + STArr n t -> TArr (unSNat n) (unSTy t) + STScal t -> TScal (unSScalTy t) + STAccum t -> TAccum (unSTy t) + +unSEnv :: SList STy env -> [Ty] +unSEnv SNil = [] +unSEnv (SCons t l) = unSTy t : unSEnv l + +unSScalTy :: SScalTy t -> ScalTy +unSScalTy = \case + STI32 -> TI32 + STI64 -> TI64 + STF32 -> TF32 + STF64 -> TF64 + STBool -> TBool + +reSTy :: Ty -> Some STy +reSTy = \case + TNil -> Some STNil + TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b' + TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b' + TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t' + TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t' + TScal t | Some t' <- reSScalTy t -> Some $ STScal t' + TAccum t | Some t' <- reSTy t -> Some $ STAccum t' + +reSEnv :: [Ty] -> Some (SList STy) +reSEnv [] = Some SNil +reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env) + +reSScalTy :: ScalTy -> Some SScalTy +reSScalTy = \case + TI32 -> Some STI32 + TI64 -> Some STI64 + TF32 -> Some STF32 + TF64 -> Some STF64 + TBool -> Some STBool + type family ScalRep t where ScalRep TI32 = Int32 ScalRep TI64 = Int64 diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index b30f7a0..0da1afc 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -51,8 +51,8 @@ zero STNil = ENil ext zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = EUnit ext (zero t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty") +zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) +zero (STArr n t) = ENothing ext (STArr n (d2 t)) zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -84,8 +84,7 @@ plus (STMaybe t) a b = plusSparse (d2 t) a b $ plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) plus (STArr n t) a b = - ELet ext a $ - ELet ext (weakenExpr WSink b) $ + plusSparse (STArr n (d2 t)) a b $ eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) (EVar ext (STArr n (d2 t)) IZ) (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) @@ -131,7 +130,8 @@ onehot typ topprj idx arg = case (typ, topprj) of (STArr n t1, SAPArrIdx prj _) -> let tidx = tTup (sreplicate n tIx) in ELet ext idx $ - EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (zero t1) + EJust ext $ + EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (zero t1) diff --git a/src/CHAD.hs b/src/CHAD.hs index a5a5719..be308cd 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -846,18 +846,18 @@ drev des = \case (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ))) + (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in + EMaybe ext (zeroTup envPro) (ESnd ext $ uninvertTup (d2e envPro) (STArr ndim STNil) $ makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ -- the cotangent for this element ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) (EVar ext shty IZ)) $ -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) (EVar ext shty (IS IZ))) $ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ in letBinds rebinds $ @@ -865,17 +865,19 @@ drev des = \case &. #pro (d2ace envPro) &. #etape (subList (bindingsBinds e0) subtapeE) &. #prerebinds prerebinds - &. #tape (tapety `SCons` SNil) - &. #ix (shty `SCons` SNil) - &. #darr (STArr ndim (d2 eltty) `SCons` SNil) - &. #tapearr (STArr ndim tapety `SCons` SNil) - &. #sh (shty `SCons` SNil) + &. #tape (auto1 @(Tape e_tape)) + &. #ix (auto1 @shty) + &. #darr (auto1 @(TArr ndim (D2 eltty))) + &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) + &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) + &. #sh (auto1 @shty) &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des))) (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) .> wPro (subList (bindingsBinds e0) subtapeE)) - e2)) + e2) + (EVar ext (d2 (STArr ndim eltty)) IZ)) }} EUnit _ e @@ -884,8 +886,11 @@ drev des = \case subtape (EUnit ext e1) sub - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy WSink) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) EReplicate1Inner _ en e -- We're allowed to ignore en2 here because the output of 'ei' is discrete. @@ -896,11 +901,14 @@ drev des = \case subtape (EReplicate1Inner ext en1 e1) sub - (ELet ext (EFold1Inner ext Commut - (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ - weakenExpr (WCopy WSink) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EJust ext (EFold1Inner ext Commut + (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext eltty) + (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) EIdx0 _ e | Ret e0 subtape e1 sub e2 <- drev des e @@ -909,7 +917,7 @@ drev des = \case subtape (EIdx0 ext e1) sub - (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $ + (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" @@ -971,10 +979,13 @@ drev des = \case (SEYes (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (ELet ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (d2 t)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EJust ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (EVar ext (d2 (STArr n t)) IZ)) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e @@ -1010,13 +1021,17 @@ drev des = \case (SEYes (SEYes subtape)) (EVar ext at' IZ) sub - (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ - ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))) - (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) - (EZero ext t)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (EMaybe ext + (zeroTup (subList (select SMerge des) sub)) + (ELet ext (EJust ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) + (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (EZero ext t))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) + (EVar ext (d2 at') IZ)) data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index e8ec0c9..7f49cef 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -20,7 +20,7 @@ type family D2 t where D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TArr n (D2 t) + D2 (TArr n t) = TMaybe (TArr n (D2 t)) D2 (TScal t) = D2s t type family D2s t where @@ -60,7 +60,7 @@ d2 STNil = STNil d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STArr n (d2 t) +d2 (STArr n t) = STMaybe (STArr n (d2 t)) d2 (STScal t) = case t of STI32 -> STNil STI64 -> STNil diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index a75fdb8..f843206 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -29,14 +29,14 @@ toTan typ primal der = case typ of (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t - | shapeSize (arrayShape der) == 0 -> - arrayMap (zeroTan t) primal - | arrayShape primal == arrayShape der -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t -> case der of + Nothing -> arrayMap (zeroTan t) primal + Just d + | arrayShape primal == arrayShape d -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Compile.hs b/src/Compile.hs index 09c3ed5..5501746 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} @@ -7,8 +9,10 @@ {-# LANGUAGE TypeApplications #-} module Compile (compile, debugCSource, debugRefc, emitChecks) where +import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Maybe import Control.Monad.Trans.State.Strict import Control.Monad.Trans.Writer.CPS import Data.Bifunctor (first) @@ -36,10 +40,12 @@ import qualified Prelude import Array import AST -import AST.Pretty (ppTy) +import AST.Pretty (ppSTy) +import qualified CHAD.Types as CHAD import Compile.Exec import Data import Interpreter.Rep +import qualified Util.IdGen as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -188,62 +194,59 @@ printCExpr d = \case ,("/", (7, (7, 8))) ,("%", (7, (7, 8)))] -repTy :: Ty -> String -repTy (TScal st) = case st of - TI32 -> "int32_t" - TI64 -> "int64_t" - TF32 -> "float" - TF64 -> "double" - TBool -> "uint8_t" -repTy t = genStructName t - repSTy :: STy t -> String -repSTy = repTy . unSTy - -genStructName :: Ty -> String +repSTy (STScal st) = case st of + STI32 -> "int32_t" + STI64 -> "int64_t" + STF32 -> "float" + STF64 -> "double" + STBool -> "uint8_t" +repSTy t = genStructName t + +genStructName :: STy t -> String genStructName = \t -> "ty_" ++ gen t where -- all tags start with a letter, so the array mangling is unambiguous. - gen :: Ty -> String - gen TNil = "n" - gen (TPair a b) = 'P' : gen a ++ gen b - gen (TEither a b) = 'E' : gen a ++ gen b - gen (TMaybe t) = 'M' : gen t - gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t - gen (TScal st) = case st of - TI32 -> "i" - TI64 -> "j" - TF32 -> "f" - TF64 -> "d" - TBool -> "b" - gen (TAccum t) = 'C' : gen t + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen t -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the -- types in the C translation. -genStruct :: String -> Ty -> [StructDecl] +genStruct :: String -> STy t -> [StructDecl] genStruct name topty = case topty of - TNil -> + STNil -> [StructDecl name "" com] - TPair a b -> - [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com] - TEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com] - TMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com] - TArr n t -> + STPair a b -> + [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + STEither a b -> -- 0 -> l, 1 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + STMaybe t -> -- 0 -> nothing, 1 -> just + [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") "" + [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" ,StructDecl name (name ++ "_buf *buf;") com] - TScal _ -> + STScal _ -> [] - TAccum t -> - [StructDecl name (repTy t ++ " ac;") com] + STAccum t -> + [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com] where - com = ppTy 0 topty + com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) () +genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () genStructs ty = do let name = genStructName ty seen <- lift $ gets (name `Set.member`) @@ -255,19 +258,19 @@ genStructs ty = do -- twice (unnecessary because no recursive types, but y'know) lift $ modify (Set.insert name) - case ty of - TNil -> pure () - TPair a b -> genStructs a >> genStructs b - TEither a b -> genStructs a >> genStructs b - TMaybe t -> genStructs t - TArr _ t -> genStructs t - TScal _ -> pure () - TAccum t -> genStructs t + () <- case ty of + STNil -> pure () + STPair a b -> genStructs a >> genStructs b + STEither a b -> genStructs a >> genStructs b + STMaybe t -> genStructs t + STArr _ t -> genStructs t + STScal _ -> pure () + STAccum t -> genStructs (CHAD.d2 t) tell (BList (genStruct name ty)) genAllStructs :: Foldable t => t Ty -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty +genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty data CompState = CompState { csStructs :: Set Ty @@ -276,36 +279,48 @@ data CompState = CompState , csNextId :: Int } deriving (Show) -type CompM a = State CompState a +newtype CompM a = CompM (State CompState a) + deriving newtype (Functor, Applicative, Monad) + +runCompM :: CompM a -> (a, CompState) +runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) -genId :: CompM Int -genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +class Monad m => MonadNameGen m where genId :: m Int +instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +instance MonadNameGen IdGen.IdGen where genId = IdGen.genId +instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) -genName' :: String -> CompM String +genName' :: MonadNameGen m => String -> m String genName' "" = genName genName' prefix = (prefix ++) . show <$> genId -genName :: CompM String +genName :: MonadNameGen m => m String genName = genName' "x" +onlyIdGen :: IdGen.IdGen a -> CompM a +onlyIdGen m = CompM $ do + i1 <- gets csNextId + let (res, i2) = IdGen.runIdGen' i1 m + modify (\s -> s { csNextId = i2 }) + return res + emit :: Stmt -> CompM () -emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } +emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } scope :: CompM a -> CompM (a, Bag Stmt) scope m = do - stmts <- state $ \s -> (csStmts s, s { csStmts = mempty }) + stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) res <- m - innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts }) + innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) return (res, innerStmts) emitStruct :: STy t -> CompM String -emitStruct ty = do - let ty' = unSTy ty - modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } - return (genStructName ty') +emitStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) } + return (genStructName ty) emitTLD :: String -> CompM () -emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } +emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) @@ -313,7 +328,7 @@ nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" compileToString :: SList STy env -> Ex env t -> String compileToString env expr = let args = nameEnv env - (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1) + (res, s) = runCompM (compile' args expr) structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) (arg_pairs, arg_metrics) = @@ -649,7 +664,7 @@ compile' env = \case x0name <- compileAssign "foldx0" env ex0 arrname <- compileAssign "foldarr" env earr - zeroRefcountCheck "fold1i" arrname + zeroRefcountCheck (typeOf earr) "fold1i" arrname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is @@ -694,7 +709,7 @@ compile' env = \case let STArr (SS n) t = typeOf e argname <- compileAssign "sumarg" env e - zeroRefcountCheck "sum1i" argname + zeroRefcountCheck (typeOf e) "sum1i" argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, like EFold1Inner. @@ -737,7 +752,7 @@ compile' env = \case lenname <- compileAssign "replen" env elen argname <- compileAssign "reparg" env earg - zeroRefcountCheck "replicate1i" argname + zeroRefcountCheck (typeOf earg) "replicate1i" argname shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -767,7 +782,7 @@ compile' env = \case EIdx0 _ e -> do let STArr _ t = typeOf e arrname <- compileAssign "" env e - zeroRefcountCheck "idx0" arrname + zeroRefcountCheck (typeOf e) "idx0" arrname name <- genName emit $ SVarDecl True (repSTy t) name (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) @@ -779,7 +794,7 @@ compile' env = \case EIdx _ earr eidx -> do let STArr n t = typeOf earr arrname <- compileAssign "ixarr" env earr - zeroRefcountCheck "idx" arrname + zeroRefcountCheck (typeOf earr) "idx" arrname idxname <- if fromSNat n > 0 -- prevent an unused-varable warning then compileAssign "ixix" env eidx else return "" -- won't be used in this case @@ -803,7 +818,7 @@ compile' env = \case t = tTup (sreplicate n tIx) _ <- emitStruct t name <- compileAssign "" env e - zeroRefcountCheck "shape" name + zeroRefcountCheck (typeOf e) "shape" name resname <- genName emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) incrementVarAlways Decrement (typeOf e) name @@ -833,15 +848,15 @@ compile' env = \case actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 - zeroRefcountCheck "with" name1 + zeroRefcountCheck (typeOf e1) "with" name1 - mcopy <- copyForWriting t name1 + mcopy <- copyForWriting (CHAD.d2 t) name1 accname <- genName' "accum" emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)]) e2' <- compile' (Const accname `SCons` env) e2 - rettyname <- emitStruct (STPair (typeOf e2) t) + rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t)) return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")] EAccum _ t prj eidx eval eacc -> do @@ -1096,7 +1111,7 @@ shapeTupFromLitVars = \n -> go n . reverse compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do - let unary cop = return @(State CompState) $ CECall cop [e1] + let unary cop = return @CompM $ CECall cop [e1] let binary cop = do name <- genName emit $ SVarDecl True (repSTy (opt1 op)) name e1 @@ -1127,7 +1142,7 @@ compileOpGeneral op e1 = do compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr compileOpPair op e1 e2 = do - let binary cop = return @(State CompState) $ CEBinop e1 cop e2 + let binary cop = return @CompM $ CEBinop e1 cop e2 case op of OAdd _ -> binary "+" OMul _ -> binary "*" @@ -1153,7 +1168,7 @@ compileExtremum nameBase opName operator env e = do let STArr (SS n) t = typeOf e argname <- compileAssign (nameBase ++ "arg") env e - zeroRefcountCheck opName argname + zeroRefcountCheck (typeOf e) opName argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is @@ -1209,7 +1224,7 @@ copyForWriting topty var = case topty of _ -> do name <- genName emit $ SVarDeclUninit (repSTy topty) name - emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (stmts1 <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) @@ -1225,7 +1240,7 @@ copyForWriting topty var = case topty of Just e1' -> do name <- genName emit $ SVarDeclUninit (repSTy topty) name - emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) (stmts1 <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) @@ -1301,13 +1316,49 @@ copyForWriting topty var = case topty of STAccum _ -> error "Compile: Nested accumulators not supported" -zeroRefcountCheck :: String -> String -> CompM () -zeroRefcountCheck opname arrvar = - when emitChecks $ - emit $ SVerbatim $ - "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++ - "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ - "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }" +zeroRefcountCheck :: STy t -> String -> String -> CompM () +zeroRefcountCheck toptyp opname topvar = + when emitChecks $ do + mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) + case mstmts of + Nothing -> return () + Just stmts -> forM_ stmts emit + where + -- | If this returns 'Nothing', no statements need to be generated for this type. + go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) + go STNil _ = empty + go (STPair a b) path = do + (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) + return (s1 <> s2) + go (STEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STMaybe a) path = do + ss <- go a (path++".j") + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty + go (STArr n a) path = do + ivar <- genName' "i" + ss <- go a (path++".buf->xs["++ivar++"]") + shszname <- genName' "shsz" + let s1 = SVerbatim $ + "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ + "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ + "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }" + let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) + let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss + return (BList [s1, s2, s3]) + go STScal{} _ = empty + go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + + combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) + combine (MaybeT a) (MaybeT b) = MaybeT $ do + x <- a + y <- b + return $ case (x, y) of + (Nothing, Nothing) -> Nothing + (Just x', Nothing) -> Just (x', mempty) + (Nothing, Just y') -> Just (mempty, y') + (Just x', Just y') -> Just (x', y') compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id diff --git a/src/Data.hs b/src/Data.hs index 60afdd0..e7b3148 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -11,6 +11,7 @@ module Data (module Data, (:~:)(Refl)) where import Data.Functor.Product +import Data.Some import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) @@ -85,6 +86,10 @@ unSNat :: SNat n -> Nat unSNat SZ = Z unSNat (SS n) = S (unSNat n) +reSNat :: Nat -> Some SNat +reSNat Z = Some SZ +reSNat (S n) | Some n' <- reSNat n = Some (SS n') + fromNat :: Nat -> Int fromNat Z = 0 fromNat (S m) = succ (fromNat m) diff --git a/src/Example.hs b/src/Example.hs index e234ff4..2c710a1 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -183,7 +183,7 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of - (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') + (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') _ -> undefined (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3cc7ae4..ddc3479 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -188,8 +188,7 @@ zeroD2 typ = case typ of STPair _ _ -> Nothing STEither _ _ -> Nothing STMaybe _ -> Nothing - STArr SZ t -> arrayUnit (zeroD2 t) - STArr n _ -> emptyArray n + STArr _ _ -> Nothing STScal sty -> case sty of STI32 -> () STI64 -> () @@ -215,13 +214,16 @@ addD2s typ a b = case typ of (Nothing, _) -> b (_, Nothing) -> a (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> - let sh1 = arrayShape a - sh2 = arrayShape b - in if | shapeSize sh1 == 0 -> b - | shapeSize sh2 == 0 -> a - | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i)) - | otherwise -> error "Plus of inconsistently shaped arrays" + STArr _ t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> + let sh1 = arrayShape x + sh2 = arrayShape y + in if | shapeSize sh1 == 0 -> Just y + | shapeSize sh2 == 0 -> Just x + | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) + | otherwise -> error "Plus of inconsistently shaped arrays" STScal sty -> case sty of STI32 -> () STI64 -> () @@ -238,7 +240,7 @@ onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx v onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = - runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx + Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) withAccum t _ initval f = AcM $ do @@ -253,7 +255,7 @@ newAcZero = \case STPair{} -> newIORef Nothing STEither{} -> newIORef Nothing STMaybe _ -> newIORef Nothing - STArr n _ -> newIORef (emptyArray n) + STArr _ _ -> newIORef Nothing STScal sty -> case sty of STI32 -> return () STI64 -> return () @@ -268,7 +270,7 @@ newAcSparse typ prj idx val = case (typ, prj) of (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val + (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val (STScal sty, SAPHere) -> case sty of STI32 -> return () STI64 -> return () @@ -286,7 +288,7 @@ newAcSparse typ prj idx val = case (typ, prj) of (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val + (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val (STAccum{}, _) -> error "Accumulators not allowed in source program" @@ -309,7 +311,7 @@ readAcSparse typ val = case typ of STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (readAcSparse t) =<< readIORef val + STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val STScal sty -> case sty of STI32 -> return () STI64 -> return () @@ -360,32 +362,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of (\ac -> accumAddSparse t1 prj' ac idx val) (STArr _ t1, SAPHere) -> - let add ac = forM_ [0 .. arraySize ac - 1] $ \i -> - unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i) - in if arraySize val == 0 - then return () - else AcM $ join $ atomicModifyIORef' ref $ \ac -> - if arraySize ac == 0 - then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val - join $ atomicModifyIORef' ref $ \ac' -> - if arraySize ac == 0 - then (newac, return ()) - else (ac', add ac')) - else (ac, add ac) + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref + (arrayMapM (newAcSparse t1 SAPHere ()) val') + (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> + accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) (STArr n t1, SAPArrIdx prj' _) -> let ((arrindex', arrsh'), idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = unTupRepIdx ShNil ShCons n arrsh' linindex = toLinearIndex arrsh arrindex - add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val - in AcM $ join $ atomicModifyIORef' ref $ \ac -> - if arraySize ac == 0 - then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx - join $ atomicModifyIORef' ref $ \ac' -> - if arraySize ac == 0 - then (newac, return ()) - else (ac', add ac')) - else (ac, add ac) + in realiseMaybeSparse ref + (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) + (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) (STScal sty, SAPHere) -> AcM $ case sty of STI32 -> return () diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index f84f4e7..be2a4cc 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -29,8 +29,7 @@ type family RepAc t where RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b)) RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) RepAc (TMaybe t) = IORef (Maybe (RepAc t)) - -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero. - RepAc (TArr n t) = IORef (Array n (RepAc t)) -- empty array is zero + RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t))) RepAc (TScal sty) = RepAcScal sty RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") diff --git a/src/Util/IdGen.hs b/src/Util/IdGen.hs index fcfb6e7..3f6611d 100644 --- a/src/Util/IdGen.hs +++ b/src/Util/IdGen.hs @@ -2,11 +2,12 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} module Util.IdGen where +import Control.Monad.Fix import Control.Monad.Trans.State.Strict newtype IdGen a = IdGen (State Int a) - deriving newtype (Functor, Applicative, Monad) + deriving newtype (Functor, Applicative, Monad, MonadFix) genId :: IdGen Int genId = IdGen (state (\i -> (i, i + 1))) |