summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
commita00234388d1b4e14481067d030bf90031258b756 (patch)
tree501b6778fc5779ce220aba1e22f56ae60f68d970 /src
parent7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff)
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs22
-rw-r--r--src/AST/Types.hs51
-rw-r--r--src/AST/UnMonoid.hs16
-rw-r--r--src/CHAD.hs75
-rw-r--r--src/CHAD/Types.hs4
-rw-r--r--src/CHAD/Types/ToTan.hs16
-rw-r--r--src/Compile.hs221
-rw-r--r--src/Data.hs5
-rw-r--r--src/Example.hs2
-rw-r--r--src/Interpreter.hs61
-rw-r--r--src/Interpreter/Rep.hs3
-rw-r--r--src/Util/IdGen.hs3
12 files changed, 285 insertions, 194 deletions
diff --git a/src/AST.hs b/src/AST.hs
index a4898c0..c8377de 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -246,28 +246,6 @@ extOf = \case
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
-unSTy :: STy t -> Ty
-unSTy = \case
- STNil -> TNil
- STPair a b -> TPair (unSTy a) (unSTy b)
- STEither a b -> TEither (unSTy a) (unSTy b)
- STMaybe t -> TMaybe (unSTy t)
- STArr n t -> TArr (unSNat n) (unSTy t)
- STScal t -> TScal (unSScalTy t)
- STAccum t -> TAccum (unSTy t)
-
-unSEnv :: SList STy env -> [Ty]
-unSEnv SNil = []
-unSEnv (SCons t l) = unSTy t : unSEnv l
-
-unSScalTy :: SScalTy t -> ScalTy
-unSScalTy = \case
- STI32 -> TI32
- STI64 -> TI64
- STF32 -> TF32
- STF64 -> TF64
- STBool -> TBool
-
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x t -> \case IZ -> repl
IS i -> EVar x t i
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index 0b41671..217b2f5 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -8,7 +9,9 @@
module AST.Types where
import Data.Int (Int32, Int64)
+import Data.GADT.Show
import Data.Kind (Type)
+import Data.Some
import Data.Type.Equality
import Data
@@ -54,6 +57,8 @@ instance TestEquality STy where
testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
testEquality STAccum{} _ = Nothing
+instance GShow STy where gshowsPrec = defaultGshowsPrec
+
data SScalTy t where
STI32 :: SScalTy TI32
STI64 :: SScalTy TI64
@@ -70,6 +75,8 @@ instance TestEquality SScalTy where
testEquality STBool STBool = Just Refl
testEquality _ _ = Nothing
+instance GShow SScalTy where gshowsPrec = defaultGshowsPrec
+
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow STI32 = Dict
scalRepIsShow STI64 = Dict
@@ -82,6 +89,50 @@ type TIx = TScal TI64
tIx :: STy TIx
tIx = STScal STI64
+unSTy :: STy t -> Ty
+unSTy = \case
+ STNil -> TNil
+ STPair a b -> TPair (unSTy a) (unSTy b)
+ STEither a b -> TEither (unSTy a) (unSTy b)
+ STMaybe t -> TMaybe (unSTy t)
+ STArr n t -> TArr (unSNat n) (unSTy t)
+ STScal t -> TScal (unSScalTy t)
+ STAccum t -> TAccum (unSTy t)
+
+unSEnv :: SList STy env -> [Ty]
+unSEnv SNil = []
+unSEnv (SCons t l) = unSTy t : unSEnv l
+
+unSScalTy :: SScalTy t -> ScalTy
+unSScalTy = \case
+ STI32 -> TI32
+ STI64 -> TI64
+ STF32 -> TF32
+ STF64 -> TF64
+ STBool -> TBool
+
+reSTy :: Ty -> Some STy
+reSTy = \case
+ TNil -> Some STNil
+ TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b'
+ TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b'
+ TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t'
+ TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t'
+ TScal t | Some t' <- reSScalTy t -> Some $ STScal t'
+ TAccum t | Some t' <- reSTy t -> Some $ STAccum t'
+
+reSEnv :: [Ty] -> Some (SList STy)
+reSEnv [] = Some SNil
+reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env)
+
+reSScalTy :: ScalTy -> Some SScalTy
+reSScalTy = \case
+ TI32 -> Some STI32
+ TI64 -> Some STI64
+ TF32 -> Some STF32
+ TF64 -> Some STF64
+ TBool -> Some STBool
+
type family ScalRep t where
ScalRep TI32 = Int32
ScalRep TI64 = Int64
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index b30f7a0..0da1afc 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -51,8 +51,8 @@ zero STNil = ENil ext
zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr SZ t) = EUnit ext (zero t)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty")
+zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t))
+zero (STArr n t) = ENothing ext (STArr n (d2 t))
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -84,8 +84,7 @@ plus (STMaybe t) a b =
plusSparse (d2 t) a b $
plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
plus (STArr n t) a b =
- ELet ext a $
- ELet ext (weakenExpr WSink b) $
+ plusSparse (STArr n (d2 t)) a b $
eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
(EVar ext (STArr n (d2 t)) IZ)
(eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
@@ -131,7 +130,8 @@ onehot typ topprj idx arg = case (typ, topprj) of
(STArr n t1, SAPArrIdx prj _) ->
let tidx = tTup (sreplicate n tIx)
in ELet ext idx $
- EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (zero t1)
+ EJust ext $
+ EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (zero t1)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a5a5719..be308cd 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -846,18 +846,18 @@ drev des = \case
(emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvCompose subMergeUsed proSub)
- (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ)))
+ (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
+ EMaybe ext
(zeroTup envPro)
(ESnd ext $
uninvertTup (d2e envPro) (STArr ndim STNil) $
makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $
-- the cotangent for this element
ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
(EVar ext shty IZ)) $
-- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))
(EVar ext shty (IS IZ))) $
let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
in letBinds rebinds $
@@ -865,17 +865,19 @@ drev des = \case
&. #pro (d2ace envPro)
&. #etape (subList (bindingsBinds e0) subtapeE)
&. #prerebinds prerebinds
- &. #tape (tapety `SCons` SNil)
- &. #ix (shty `SCons` SNil)
- &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
- &. #tapearr (STArr ndim tapety `SCons` SNil)
- &. #sh (shty `SCons` SNil)
+ &. #tape (auto1 @(Tape e_tape))
+ &. #ix (auto1 @shty)
+ &. #darr (auto1 @(TArr ndim (D2 eltty)))
+ &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty))))
+ &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
+ &. #sh (auto1 @shty)
&. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des)))
(#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)
.> wPro (subList (bindingsBinds e0) subtapeE))
- e2))
+ e2)
+ (EVar ext (d2 (STArr ndim eltty)) IZ))
}}
EUnit _ e
@@ -884,8 +886,11 @@ drev des = \case
subtape
(EUnit ext e1)
sub
- (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
- weakenExpr (WCopy WSink) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+ (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))
EReplicate1Inner _ en e
-- We're allowed to ignore en2 here because the output of 'ei' is discrete.
@@ -896,11 +901,14 @@ drev des = \case
subtape
(EReplicate1Inner ext en1 e1)
sub
- (ELet ext (EFold1Inner ext Commut
- (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (EZero ext eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
- weakenExpr (WCopy WSink) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EJust ext (EFold1Inner ext Commut
+ (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext eltty)
+ (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+ (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
EIdx0 _ e
| Ret e0 subtape e1 sub e2 <- drev des e
@@ -909,7 +917,7 @@ drev des = \case
subtape
(EIdx0 ext e1)
sub
- (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $
+ (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
weakenExpr (WCopy WSink) e2)
EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
@@ -971,10 +979,13 @@ drev des = \case
(SEYes (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (ELet ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
- (EVar ext (STArr n (d2 t)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EJust ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
+ (EVar ext (STArr n (d2 t)) IZ))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ (EVar ext (d2 (STArr n t)) IZ))
EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
@@ -1010,13 +1021,17 @@ drev des = \case
(SEYes (SEYes subtape))
(EVar ext at' IZ)
sub
- (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
- ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))))
- (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
- (EZero ext t)) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ (EMaybe ext
+ (zeroTup (subList (select SMerge des) sub))
+ (ELet ext (EJust ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
+ (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (EZero ext t))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
+ (EVar ext (d2 at') IZ))
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index e8ec0c9..7f49cef 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -20,7 +20,7 @@ type family D2 t where
D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TArr n (D2 t)
+ D2 (TArr n t) = TMaybe (TArr n (D2 t))
D2 (TScal t) = D2s t
type family D2s t where
@@ -60,7 +60,7 @@ d2 STNil = STNil
d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
d2 (STMaybe t) = STMaybe (d2 t)
-d2 (STArr n t) = STArr n (d2 t)
+d2 (STArr n t) = STMaybe (STArr n (d2 t))
d2 (STScal t) = case t of
STI32 -> STNil
STI64 -> STNil
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index a75fdb8..f843206 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -29,14 +29,14 @@ toTan typ primal der = case typ of
(Right p, Right d') -> Right (toTan t2 p d')
_ -> error "Primal and cotangent disagree on Either alternative"
STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t
- | shapeSize (arrayShape der) == 0 ->
- arrayMap (zeroTan t) primal
- | arrayShape primal == arrayShape der ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
+ STArr _ t -> case der of
+ Nothing -> arrayMap (zeroTan t) primal
+ Just d
+ | arrayShape primal == arrayShape d ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
diff --git a/src/Compile.hs b/src/Compile.hs
index 09c3ed5..5501746 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,5 +1,7 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
@@ -7,8 +9,10 @@
{-# LANGUAGE TypeApplications #-}
module Compile (compile, debugCSource, debugRefc, emitChecks) where
+import Control.Applicative (empty)
import Control.Monad (forM_, when, replicateM)
import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Maybe
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Bifunctor (first)
@@ -36,10 +40,12 @@ import qualified Prelude
import Array
import AST
-import AST.Pretty (ppTy)
+import AST.Pretty (ppSTy)
+import qualified CHAD.Types as CHAD
import Compile.Exec
import Data
import Interpreter.Rep
+import qualified Util.IdGen as IdGen
-- In shape and index arrays, the innermost dimension is on the right (last index).
@@ -188,62 +194,59 @@ printCExpr d = \case
,("/", (7, (7, 8)))
,("%", (7, (7, 8)))]
-repTy :: Ty -> String
-repTy (TScal st) = case st of
- TI32 -> "int32_t"
- TI64 -> "int64_t"
- TF32 -> "float"
- TF64 -> "double"
- TBool -> "uint8_t"
-repTy t = genStructName t
-
repSTy :: STy t -> String
-repSTy = repTy . unSTy
-
-genStructName :: Ty -> String
+repSTy (STScal st) = case st of
+ STI32 -> "int32_t"
+ STI64 -> "int64_t"
+ STF32 -> "float"
+ STF64 -> "double"
+ STBool -> "uint8_t"
+repSTy t = genStructName t
+
+genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t where
-- all tags start with a letter, so the array mangling is unambiguous.
- gen :: Ty -> String
- gen TNil = "n"
- gen (TPair a b) = 'P' : gen a ++ gen b
- gen (TEither a b) = 'E' : gen a ++ gen b
- gen (TMaybe t) = 'M' : gen t
- gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t
- gen (TScal st) = case st of
- TI32 -> "i"
- TI64 -> "j"
- TF32 -> "f"
- TF64 -> "d"
- TBool -> "b"
- gen (TAccum t) = 'C' : gen t
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen t
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
-- types in the C translation.
-genStruct :: String -> Ty -> [StructDecl]
+genStruct :: String -> STy t -> [StructDecl]
genStruct name topty = case topty of
- TNil ->
+ STNil ->
[StructDecl name "" com]
- TPair a b ->
- [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com]
- TEither a b -> -- 0 -> l, 1 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com]
- TMaybe t -> -- 0 -> nothing, 1 -> just
- [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com]
- TArr n t ->
+ STPair a b ->
+ [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ STEither a b -> -- 0 -> l, 1 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STMaybe t -> -- 0 -> nothing, 1 -> just
+ [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") ""
+ [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
,StructDecl name (name ++ "_buf *buf;") com]
- TScal _ ->
+ STScal _ ->
[]
- TAccum t ->
- [StructDecl name (repTy t ++ " ac;") com]
+ STAccum t ->
+ [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com]
where
- com = ppTy 0 topty
+ com = ppSTy 0 topty
-- State: already-generated (skippable) struct names
-- Writer: the structs in declaration order
-genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()
genStructs ty = do
let name = genStructName ty
seen <- lift $ gets (name `Set.member`)
@@ -255,19 +258,19 @@ genStructs ty = do
-- twice (unnecessary because no recursive types, but y'know)
lift $ modify (Set.insert name)
- case ty of
- TNil -> pure ()
- TPair a b -> genStructs a >> genStructs b
- TEither a b -> genStructs a >> genStructs b
- TMaybe t -> genStructs t
- TArr _ t -> genStructs t
- TScal _ -> pure ()
- TAccum t -> genStructs t
+ () <- case ty of
+ STNil -> pure ()
+ STPair a b -> genStructs a >> genStructs b
+ STEither a b -> genStructs a >> genStructs b
+ STMaybe t -> genStructs t
+ STArr _ t -> genStructs t
+ STScal _ -> pure ()
+ STAccum t -> genStructs (CHAD.d2 t)
tell (BList (genStruct name ty))
genAllStructs :: Foldable t => t Ty -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty
+genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty
data CompState = CompState
{ csStructs :: Set Ty
@@ -276,36 +279,48 @@ data CompState = CompState
, csNextId :: Int }
deriving (Show)
-type CompM a = State CompState a
+newtype CompM a = CompM (State CompState a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runCompM :: CompM a -> (a, CompState)
+runCompM (CompM m) = runState m (CompState mempty mempty mempty 1)
-genId :: CompM Int
-genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+class Monad m => MonadNameGen m where genId :: m Int
+instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
+instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
-genName' :: String -> CompM String
+genName' :: MonadNameGen m => String -> m String
genName' "" = genName
genName' prefix = (prefix ++) . show <$> genId
-genName :: CompM String
+genName :: MonadNameGen m => m String
genName = genName' "x"
+onlyIdGen :: IdGen.IdGen a -> CompM a
+onlyIdGen m = CompM $ do
+ i1 <- gets csNextId
+ let (res, i2) = IdGen.runIdGen' i1 m
+ modify (\s -> s { csNextId = i2 })
+ return res
+
emit :: Stmt -> CompM ()
-emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt }
+emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }
scope :: CompM a -> CompM (a, Bag Stmt)
scope m = do
- stmts <- state $ \s -> (csStmts s, s { csStmts = mempty })
+ stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })
res <- m
- innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts })
+ innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })
return (res, innerStmts)
emitStruct :: STy t -> CompM String
-emitStruct ty = do
- let ty' = unSTy ty
- modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) }
- return (genStructName ty')
+emitStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) }
+ return (genStructName ty)
emitTLD :: String -> CompM ()
-emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
+emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
@@ -313,7 +328,7 @@ nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg"
compileToString :: SList STy env -> Ex env t -> String
compileToString env expr =
let args = nameEnv env
- (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1)
+ (res, s) = runCompM (compile' args expr)
structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
(arg_pairs, arg_metrics) =
@@ -649,7 +664,7 @@ compile' env = \case
x0name <- compileAssign "foldx0" env ex0
arrname <- compileAssign "foldarr" env earr
- zeroRefcountCheck "fold1i" arrname
+ zeroRefcountCheck (typeOf earr) "fold1i" arrname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
@@ -694,7 +709,7 @@ compile' env = \case
let STArr (SS n) t = typeOf e
argname <- compileAssign "sumarg" env e
- zeroRefcountCheck "sum1i" argname
+ zeroRefcountCheck (typeOf e) "sum1i" argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
@@ -737,7 +752,7 @@ compile' env = \case
lenname <- compileAssign "replen" env elen
argname <- compileAssign "reparg" env earg
- zeroRefcountCheck "replicate1i" argname
+ zeroRefcountCheck (typeOf earg) "replicate1i" argname
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
@@ -767,7 +782,7 @@ compile' env = \case
EIdx0 _ e -> do
let STArr _ t = typeOf e
arrname <- compileAssign "" env e
- zeroRefcountCheck "idx0" arrname
+ zeroRefcountCheck (typeOf e) "idx0" arrname
name <- genName
emit $ SVarDecl True (repSTy t) name
(CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
@@ -779,7 +794,7 @@ compile' env = \case
EIdx _ earr eidx -> do
let STArr n t = typeOf earr
arrname <- compileAssign "ixarr" env earr
- zeroRefcountCheck "idx" arrname
+ zeroRefcountCheck (typeOf earr) "idx" arrname
idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
then compileAssign "ixix" env eidx
else return "" -- won't be used in this case
@@ -803,7 +818,7 @@ compile' env = \case
t = tTup (sreplicate n tIx)
_ <- emitStruct t
name <- compileAssign "" env e
- zeroRefcountCheck "shape" name
+ zeroRefcountCheck (typeOf e) "shape" name
resname <- genName
emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
incrementVarAlways Decrement (typeOf e) name
@@ -833,15 +848,15 @@ compile' env = \case
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
- zeroRefcountCheck "with" name1
+ zeroRefcountCheck (typeOf e1) "with" name1
- mcopy <- copyForWriting t name1
+ mcopy <- copyForWriting (CHAD.d2 t) name1
accname <- genName' "accum"
emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])
e2' <- compile' (Const accname `SCons` env) e2
- rettyname <- emitStruct (STPair (typeOf e2) t)
+ rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))
return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")]
EAccum _ t prj eidx eval eacc -> do
@@ -1096,7 +1111,7 @@ shapeTupFromLitVars = \n -> go n . reverse
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
- let unary cop = return @(State CompState) $ CECall cop [e1]
+ let unary cop = return @CompM $ CECall cop [e1]
let binary cop = do
name <- genName
emit $ SVarDecl True (repSTy (opt1 op)) name e1
@@ -1127,7 +1142,7 @@ compileOpGeneral op e1 = do
compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
compileOpPair op e1 e2 = do
- let binary cop = return @(State CompState) $ CEBinop e1 cop e2
+ let binary cop = return @CompM $ CEBinop e1 cop e2
case op of
OAdd _ -> binary "+"
OMul _ -> binary "*"
@@ -1153,7 +1168,7 @@ compileExtremum nameBase opName operator env e = do
let STArr (SS n) t = typeOf e
argname <- compileAssign (nameBase ++ "arg") env e
- zeroRefcountCheck opName argname
+ zeroRefcountCheck (typeOf e) opName argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
@@ -1209,7 +1224,7 @@ copyForWriting topty var = case topty of
_ -> do
name <- genName
emit $ SVarDeclUninit (repSTy topty) name
- emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
(stmts1
<> pure (SAsg name (CEStruct (repSTy topty)
[("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
@@ -1225,7 +1240,7 @@ copyForWriting topty var = case topty of
Just e1' -> do
name <- genName
emit $ SVarDeclUninit (repSTy topty) name
- emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
(pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))
(stmts1
<> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')])))
@@ -1301,13 +1316,49 @@ copyForWriting topty var = case topty of
STAccum _ -> error "Compile: Nested accumulators not supported"
-zeroRefcountCheck :: String -> String -> CompM ()
-zeroRefcountCheck opname arrvar =
- when emitChecks $
- emit $ SVerbatim $
- "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++
- "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
- "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }"
+zeroRefcountCheck :: STy t -> String -> String -> CompM ()
+zeroRefcountCheck toptyp opname topvar =
+ when emitChecks $ do
+ mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
+ case mstmts of
+ Nothing -> return ()
+ Just stmts -> forM_ stmts emit
+ where
+ -- | If this returns 'Nothing', no statements need to be generated for this type.
+ go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
+ go STNil _ = empty
+ go (STPair a b) path = do
+ (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
+ return (s1 <> s2)
+ go (STEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STMaybe a) path = do
+ ss <- go a (path++".j")
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
+ go (STArr n a) path = do
+ ivar <- genName' "i"
+ ss <- go a (path++".buf->xs["++ivar++"]")
+ shszname <- genName' "shsz"
+ let s1 = SVerbatim $
+ "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
+ "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
+ "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }"
+ let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
+ let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
+ return (BList [s1, s2, s3])
+ go STScal{} _ = empty
+ go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
+
+ combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
+ combine (MaybeT a) (MaybeT b) = MaybeT $ do
+ x <- a
+ y <- b
+ return $ case (x, y) of
+ (Nothing, Nothing) -> Nothing
+ (Just x', Nothing) -> Just (x', mempty)
+ (Nothing, Just y') -> Just (mempty, y')
+ (Just x', Just y') -> Just (x', y')
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
diff --git a/src/Data.hs b/src/Data.hs
index 60afdd0..e7b3148 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -11,6 +11,7 @@
module Data (module Data, (:~:)(Refl)) where
import Data.Functor.Product
+import Data.Some
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
@@ -85,6 +86,10 @@ unSNat :: SNat n -> Nat
unSNat SZ = Z
unSNat (SS n) = S (unSNat n)
+reSNat :: Nat -> Some SNat
+reSNat Z = Some SZ
+reSNat (S n) | Some n' <- reSNat n = Some (SS n')
+
fromNat :: Nat -> Int
fromNat Z = 0
fromNat (S m) = succ (fromNat m)
diff --git a/src/Example.hs b/src/Example.hs
index e234ff4..2c710a1 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -183,7 +183,7 @@ neuralGo =
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
- (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
+ (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
_ -> undefined
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 3cc7ae4..ddc3479 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -188,8 +188,7 @@ zeroD2 typ = case typ of
STPair _ _ -> Nothing
STEither _ _ -> Nothing
STMaybe _ -> Nothing
- STArr SZ t -> arrayUnit (zeroD2 t)
- STArr n _ -> emptyArray n
+ STArr _ _ -> Nothing
STScal sty -> case sty of
STI32 -> ()
STI64 -> ()
@@ -215,13 +214,16 @@ addD2s typ a b = case typ of
(Nothing, _) -> b
(_, Nothing) -> a
(Just x, Just y) -> Just (addD2s t x y)
- STArr _ t ->
- let sh1 = arrayShape a
- sh2 = arrayShape b
- in if | shapeSize sh1 == 0 -> b
- | shapeSize sh2 == 0 -> a
- | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i))
- | otherwise -> error "Plus of inconsistently shaped arrays"
+ STArr _ t -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just x, Just y) ->
+ let sh1 = arrayShape x
+ sh2 = arrayShape y
+ in if | shapeSize sh1 == 0 -> Just y
+ | shapeSize sh2 == 0 -> Just x
+ | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
STScal sty -> case sty of
STI32 -> ()
STI64 -> ()
@@ -238,7 +240,7 @@ onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx v
onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
- runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
+ Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
withAccum t _ initval f = AcM $ do
@@ -253,7 +255,7 @@ newAcZero = \case
STPair{} -> newIORef Nothing
STEither{} -> newIORef Nothing
STMaybe _ -> newIORef Nothing
- STArr n _ -> newIORef (emptyArray n)
+ STArr _ _ -> newIORef Nothing
STScal sty -> case sty of
STI32 -> return ()
STI64 -> return ()
@@ -268,7 +270,7 @@ newAcSparse typ prj idx val = case (typ, prj) of
(STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
(STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
(STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
- (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
+ (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val
(STScal sty, SAPHere) -> case sty of
STI32 -> return ()
STI64 -> return ()
@@ -286,7 +288,7 @@ newAcSparse typ prj idx val = case (typ, prj) of
(STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
- (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val
+ (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val
(STAccum{}, _) -> error "Accumulators not allowed in source program"
@@ -309,7 +311,7 @@ readAcSparse typ val = case typ of
STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
STMaybe t -> traverse (readAcSparse t) =<< readIORef val
- STArr _ t -> traverse (readAcSparse t) =<< readIORef val
+ STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val
STScal sty -> case sty of
STI32 -> return ()
STI64 -> return ()
@@ -360,32 +362,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
(\ac -> accumAddSparse t1 prj' ac idx val)
(STArr _ t1, SAPHere) ->
- let add ac = forM_ [0 .. arraySize ac - 1] $ \i ->
- unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i)
- in if arraySize val == 0
- then return ()
- else AcM $ join $ atomicModifyIORef' ref $ \ac ->
- if arraySize ac == 0
- then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val
- join $ atomicModifyIORef' ref $ \ac' ->
- if arraySize ac == 0
- then (newac, return ())
- else (ac', add ac'))
- else (ac, add ac)
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref
+ (arrayMapM (newAcSparse t1 SAPHere ()) val')
+ (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
+ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
(STArr n t1, SAPArrIdx prj' _) ->
let ((arrindex', arrsh'), idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = unTupRepIdx ShNil ShCons n arrsh'
linindex = toLinearIndex arrsh arrindex
- add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val
- in AcM $ join $ atomicModifyIORef' ref $ \ac ->
- if arraySize ac == 0
- then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx
- join $ atomicModifyIORef' ref $ \ac' ->
- if arraySize ac == 0
- then (newac, return ())
- else (ac', add ac'))
- else (ac, add ac)
+ in realiseMaybeSparse ref
+ (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
+ (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
(STScal sty, SAPHere) -> AcM $ case sty of
STI32 -> return ()
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index f84f4e7..be2a4cc 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -29,8 +29,7 @@ type family RepAc t where
RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b))
RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))
RepAc (TMaybe t) = IORef (Maybe (RepAc t))
- -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero.
- RepAc (TArr n t) = IORef (Array n (RepAc t)) -- empty array is zero
+ RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t)))
RepAc (TScal sty) = RepAcScal sty
RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
diff --git a/src/Util/IdGen.hs b/src/Util/IdGen.hs
index fcfb6e7..3f6611d 100644
--- a/src/Util/IdGen.hs
+++ b/src/Util/IdGen.hs
@@ -2,11 +2,12 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Util.IdGen where
+import Control.Monad.Fix
import Control.Monad.Trans.State.Strict
newtype IdGen a = IdGen (State Int a)
- deriving newtype (Functor, Applicative, Monad)
+ deriving newtype (Functor, Applicative, Monad, MonadFix)
genId :: IdGen Int
genId = IdGen (state (\i -> (i, i + 1)))