diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 | 
| commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
| tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 | |
| parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) | |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
| -rw-r--r-- | src/AST.hs | 22 | ||||
| -rw-r--r-- | src/AST/Types.hs | 51 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 16 | ||||
| -rw-r--r-- | src/CHAD.hs | 75 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/Types/ToTan.hs | 16 | ||||
| -rw-r--r-- | src/Compile.hs | 219 | ||||
| -rw-r--r-- | src/Data.hs | 5 | ||||
| -rw-r--r-- | src/Example.hs | 2 | ||||
| -rw-r--r-- | src/Interpreter.hs | 61 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 3 | ||||
| -rw-r--r-- | src/Util/IdGen.hs | 3 | 
12 files changed, 284 insertions, 193 deletions
| @@ -246,28 +246,6 @@ extOf = \case    EOneHot x _ _ _ _ -> x    EError x _ _ -> x -unSTy :: STy t -> Ty -unSTy = \case -  STNil -> TNil -  STPair a b -> TPair (unSTy a) (unSTy b) -  STEither a b -> TEither (unSTy a) (unSTy b) -  STMaybe t -> TMaybe (unSTy t) -  STArr n t -> TArr (unSNat n) (unSTy t) -  STScal t -> TScal (unSScalTy t) -  STAccum t -> TAccum (unSTy t) - -unSEnv :: SList STy env -> [Ty] -unSEnv SNil = [] -unSEnv (SCons t l) = unSTy t : unSEnv l - -unSScalTy :: SScalTy t -> ScalTy -unSScalTy = \case -  STI32 -> TI32 -  STI64 -> TI64 -  STF32 -> TF32 -  STF64 -> TF64 -  STBool -> TBool -  subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t  subst1 repl = subst $ \x t -> \case IZ -> repl                                      IS i -> EVar x t i diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 0b41671..217b2f5 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -1,5 +1,6 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} @@ -8,7 +9,9 @@  module AST.Types where  import Data.Int (Int32, Int64) +import Data.GADT.Show  import Data.Kind (Type) +import Data.Some  import Data.Type.Equality  import Data @@ -54,6 +57,8 @@ instance TestEquality STy where    testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl    testEquality STAccum{} _ = Nothing +instance GShow STy where gshowsPrec = defaultGshowsPrec +  data SScalTy t where    STI32 :: SScalTy TI32    STI64 :: SScalTy TI64 @@ -70,6 +75,8 @@ instance TestEquality SScalTy where    testEquality STBool STBool = Just Refl    testEquality _ _ = Nothing +instance GShow SScalTy where gshowsPrec = defaultGshowsPrec +  scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))  scalRepIsShow STI32 = Dict  scalRepIsShow STI64 = Dict @@ -82,6 +89,50 @@ type TIx = TScal TI64  tIx :: STy TIx  tIx = STScal STI64 +unSTy :: STy t -> Ty +unSTy = \case +  STNil -> TNil +  STPair a b -> TPair (unSTy a) (unSTy b) +  STEither a b -> TEither (unSTy a) (unSTy b) +  STMaybe t -> TMaybe (unSTy t) +  STArr n t -> TArr (unSNat n) (unSTy t) +  STScal t -> TScal (unSScalTy t) +  STAccum t -> TAccum (unSTy t) + +unSEnv :: SList STy env -> [Ty] +unSEnv SNil = [] +unSEnv (SCons t l) = unSTy t : unSEnv l + +unSScalTy :: SScalTy t -> ScalTy +unSScalTy = \case +  STI32 -> TI32 +  STI64 -> TI64 +  STF32 -> TF32 +  STF64 -> TF64 +  STBool -> TBool + +reSTy :: Ty -> Some STy +reSTy = \case +  TNil -> Some STNil +  TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b' +  TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b' +  TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t' +  TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t' +  TScal t | Some t' <- reSScalTy t -> Some $ STScal t' +  TAccum t | Some t' <- reSTy t -> Some $ STAccum t' + +reSEnv :: [Ty] -> Some (SList STy) +reSEnv [] = Some SNil +reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env) + +reSScalTy :: ScalTy -> Some SScalTy +reSScalTy = \case +  TI32 -> Some STI32 +  TI64 -> Some STI64 +  TF32 -> Some STF32 +  TF64 -> Some STF64 +  TBool -> Some STBool +  type family ScalRep t where    ScalRep TI32 = Int32    ScalRep TI64 = Int64 diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index b30f7a0..0da1afc 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -51,8 +51,8 @@ zero STNil = ENil ext  zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))  zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))  zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = EUnit ext (zero t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty") +zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) +zero (STArr n t) = ENothing ext (STArr n (d2 t))  zero (STScal t) = case t of    STI32 -> ENil ext    STI64 -> ENil ext @@ -84,8 +84,7 @@ plus (STMaybe t) a b =    plusSparse (d2 t) a b $      plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)  plus (STArr n t) a b = -  ELet ext a $ -  ELet ext (weakenExpr WSink b) $ +  plusSparse (STArr n (d2 t)) a b $      eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))          (EVar ext (STArr n (d2 t)) IZ)          (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) @@ -131,7 +130,8 @@ onehot typ topprj idx arg = case (typ, topprj) of    (STArr n t1, SAPArrIdx prj _) ->      let tidx = tTup (sreplicate n tIx)      in ELet ext idx $ -         EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ -           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) -             (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) -             (zero t1) +         EJust ext $ +           EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ +             eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) +               (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) +               (zero t1) diff --git a/src/CHAD.hs b/src/CHAD.hs index a5a5719..be308cd 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -846,18 +846,18 @@ drev des = \case          (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))                (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))          (subenvCompose subMergeUsed proSub) -        (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in -         eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ))) +        (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in +         EMaybe ext             (zeroTup envPro)             (ESnd ext $                uninvertTup (d2e envPro) (STArr ndim STNil) $                  makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ -                  EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ +                  EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $                      -- the cotangent for this element                      ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))                                         (EVar ext shty IZ)) $                      -- the tape for this element -                    ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) +                    ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))                                         (EVar ext shty (IS IZ))) $                      let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ                      in letBinds rebinds $ @@ -865,17 +865,19 @@ drev des = \case                                                 &. #pro (d2ace envPro)                                                 &. #etape (subList (bindingsBinds e0) subtapeE)                                                 &. #prerebinds prerebinds -                                               &. #tape (tapety `SCons` SNil) -                                               &. #ix (shty `SCons` SNil) -                                               &. #darr (STArr ndim (d2 eltty) `SCons` SNil) -                                               &. #tapearr (STArr ndim tapety `SCons` SNil) -                                               &. #sh (shty `SCons` SNil) +                                               &. #tape (auto1 @(Tape e_tape)) +                                               &. #ix (auto1 @shty) +                                               &. #darr (auto1 @(TArr ndim (D2 eltty))) +                                               &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) +                                               &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) +                                               &. #sh (auto1 @shty)                                                 &. #d2acUsed (d2ace (select SAccum usedDes))                                                 &. #d2acEnv (d2ace (select SAccum des)))                                                (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) -                                              ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) +                                              ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)                                       .> wPro (subList (bindingsBinds e0) subtapeE)) -                                    e2)) +                                    e2) +           (EVar ext (d2 (STArr ndim eltty)) IZ))      }}    EUnit _ e @@ -884,8 +886,11 @@ drev des = \case          subtape          (EUnit ext e1)          sub -        (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ -         weakenExpr (WCopy WSink) e2) +        (EMaybe ext +          (zeroTup (subList (select SMerge des) sub)) +          (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ +             weakenExpr (WCopy (WSink .> WSink)) e2) +          (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))    EReplicate1Inner _ en e      -- We're allowed to ignore en2 here because the output of 'ei' is discrete. @@ -896,11 +901,14 @@ drev des = \case          subtape          (EReplicate1Inner ext en1 e1)          sub -        (ELet ext (EFold1Inner ext Commut -                      (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) -                      (EZero ext eltty) -                      (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ -          weakenExpr (WCopy WSink) e2) +        (EMaybe ext +          (zeroTup (subList (select SMerge des) sub)) +          (ELet ext (EJust ext (EFold1Inner ext Commut +                        (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) +                        (EZero ext eltty) +                        (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ +            weakenExpr (WCopy (WSink .> WSink)) e2) +          (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))    EIdx0 _ e      | Ret e0 subtape e1 sub e2 <- drev des e @@ -909,7 +917,7 @@ drev des = \case          subtape          (EIdx0 ext e1)          sub -        (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $ +        (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $           weakenExpr (WCopy WSink) e2)    EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" @@ -971,10 +979,13 @@ drev des = \case          (SEYes (SENo subtape))          (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))          sub -        (ELet ext (EReplicate1Inner ext -                     (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) -                     (EVar ext (STArr n (d2 t)) IZ)) $ -         weakenExpr (WCopy (WSink .> WSink)) e2) +        (EMaybe ext +          (zeroTup (subList (select SMerge des) sub)) +          (ELet ext (EJust ext (EReplicate1Inner ext +                                  (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) +                                  (EVar ext (STArr n (d2 t)) IZ))) $ +           weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +          (EVar ext (d2 (STArr n t)) IZ))    EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e    EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e @@ -1010,13 +1021,17 @@ drev des = \case            (SEYes (SEYes subtape))            (EVar ext at' IZ)            sub -          (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ -                       ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext -                                    (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) -                                    (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))) -                         (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) -                         (EZero ext t)) $ -            weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +          (EMaybe ext +            (zeroTup (subList (select SMerge des) sub)) +            (ELet ext (EJust ext +                        (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ +                           eif (EOp ext (OEq st) (EPair ext +                                        (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) +                                        (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) +                             (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) +                             (EZero ext t))) $ +              weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) +            (EVar ext (d2 at') IZ))  data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index e8ec0c9..7f49cef 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -20,7 +20,7 @@ type family D2 t where    D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))    D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))    D2 (TMaybe t) = TMaybe (D2 t) -  D2 (TArr n t) = TArr n (D2 t) +  D2 (TArr n t) = TMaybe (TArr n (D2 t))    D2 (TScal t) = D2s t  type family D2s t where @@ -60,7 +60,7 @@ d2 STNil = STNil  d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))  d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))  d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STArr n (d2 t) +d2 (STArr n t) = STMaybe (STArr n (d2 t))  d2 (STScal t) = case t of    STI32 -> STNil    STI64 -> STNil diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index a75fdb8..f843206 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -29,14 +29,14 @@ toTan typ primal der = case typ of                          (Right p, Right d') -> Right (toTan t2 p d')                          _ -> error "Primal and cotangent disagree on Either alternative"    STMaybe t -> liftA2 (toTan t) primal der -  STArr _ t -    | shapeSize (arrayShape der) == 0 -> -        arrayMap (zeroTan t) primal -    | arrayShape primal == arrayShape der -> -        arrayGenerateLin (arrayShape primal) $ \i -> -          toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) -    | otherwise -> -        error "Primal and cotangent disagree on array shape" +  STArr _ t -> case der of +                 Nothing -> arrayMap (zeroTan t) primal +                 Just d +                   | arrayShape primal == arrayShape d -> +                       arrayGenerateLin (arrayShape primal) $ \i -> +                         toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) +                   | otherwise -> +                       error "Primal and cotangent disagree on array shape"    STScal sty -> case sty of      STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der    STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Compile.hs b/src/Compile.hs index 09c3ed5..5501746 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,5 +1,7 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE MagicHash #-}  {-# LANGUAGE MultiWayIf #-} @@ -7,8 +9,10 @@  {-# LANGUAGE TypeApplications #-}  module Compile (compile, debugCSource, debugRefc, emitChecks) where +import Control.Applicative (empty)  import Control.Monad (forM_, when, replicateM)  import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Maybe  import Control.Monad.Trans.State.Strict  import Control.Monad.Trans.Writer.CPS  import Data.Bifunctor (first) @@ -36,10 +40,12 @@ import qualified Prelude  import Array  import AST -import AST.Pretty (ppTy) +import AST.Pretty (ppSTy) +import qualified CHAD.Types as CHAD  import Compile.Exec  import Data  import Interpreter.Rep +import qualified Util.IdGen as IdGen  -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -188,62 +194,59 @@ printCExpr d = \case        ,("/", (7, (7, 8)))        ,("%", (7, (7, 8)))] -repTy :: Ty -> String -repTy (TScal st) = case st of -  TI32 -> "int32_t" -  TI64 -> "int64_t" -  TF32 -> "float" -  TF64 -> "double" -  TBool -> "uint8_t" -repTy t = genStructName t -  repSTy :: STy t -> String -repSTy = repTy . unSTy +repSTy (STScal st) = case st of +  STI32 -> "int32_t" +  STI64 -> "int64_t" +  STF32 -> "float" +  STF64 -> "double" +  STBool -> "uint8_t" +repSTy t = genStructName t -genStructName :: Ty -> String +genStructName :: STy t -> String  genStructName = \t -> "ty_" ++ gen t where    -- all tags start with a letter, so the array mangling is unambiguous. -  gen :: Ty -> String -  gen TNil = "n" -  gen (TPair a b) = 'P' : gen a ++ gen b -  gen (TEither a b) = 'E' : gen a ++ gen b -  gen (TMaybe t) = 'M' : gen t -  gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t -  gen (TScal st) = case st of -    TI32 -> "i" -    TI64 -> "j" -    TF32 -> "f" -    TF64 -> "d" -    TBool -> "b" -  gen (TAccum t) = 'C' : gen t +  gen :: STy t -> String +  gen STNil = "n" +  gen (STPair a b) = 'P' : gen a ++ gen b +  gen (STEither a b) = 'E' : gen a ++ gen b +  gen (STMaybe t) = 'M' : gen t +  gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t +  gen (STScal st) = case st of +    STI32 -> "i" +    STI64 -> "j" +    STF32 -> "f" +    STF64 -> "d" +    STBool -> "b" +  gen (STAccum t) = 'C' : gen t  -- | This function generates the actual struct declarations for each of the  -- types in our language. It thus implicitly "documents" the layout of the  -- types in the C translation. -genStruct :: String -> Ty -> [StructDecl] +genStruct :: String -> STy t -> [StructDecl]  genStruct name topty = case topty of -  TNil -> +  STNil ->      [StructDecl name "" com] -  TPair a b -> -    [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com] -  TEither a b ->  -- 0 -> l, 1 -> r -    [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com] -  TMaybe t ->  -- 0 -> nothing, 1 -> just -    [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com] -  TArr n t -> +  STPair a b -> +    [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] +  STEither a b ->  -- 0 -> l, 1 -> r +    [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] +  STMaybe t ->  -- 0 -> nothing, 1 -> just +    [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] +  STArr n t ->      -- The buffer is trailed by a VLA for the actual array data. -    [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") "" +    [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""      ,StructDecl name (name ++ "_buf *buf;") com] -  TScal _ -> +  STScal _ ->      [] -  TAccum t -> -    [StructDecl name (repTy t ++ " ac;") com] +  STAccum t -> +    [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com]    where -    com = ppTy 0 topty +    com = ppSTy 0 topty  -- State: already-generated (skippable) struct names  -- Writer: the structs in declaration order -genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) () +genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()  genStructs ty = do    let name = genStructName ty    seen <- lift $ gets (name `Set.member`) @@ -255,19 +258,19 @@ genStructs ty = do        -- twice (unnecessary because no recursive types, but y'know)        lift $ modify (Set.insert name) -      case ty of -        TNil -> pure () -        TPair a b -> genStructs a >> genStructs b -        TEither a b -> genStructs a >> genStructs b -        TMaybe t -> genStructs t -        TArr _ t -> genStructs t -        TScal _ -> pure () -        TAccum t -> genStructs t +      () <- case ty of +        STNil -> pure () +        STPair a b -> genStructs a >> genStructs b +        STEither a b -> genStructs a >> genStructs b +        STMaybe t -> genStructs t +        STArr _ t -> genStructs t +        STScal _ -> pure () +        STAccum t -> genStructs (CHAD.d2 t)        tell (BList (genStruct name ty))  genAllStructs :: Foldable t => t Ty -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty +genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty  data CompState = CompState    { csStructs :: Set Ty @@ -276,36 +279,48 @@ data CompState = CompState    , csNextId :: Int }    deriving (Show) -type CompM a = State CompState a +newtype CompM a = CompM (State CompState a) +  deriving newtype (Functor, Applicative, Monad) + +runCompM :: CompM a -> (a, CompState) +runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) -genId :: CompM Int -genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +class Monad m => MonadNameGen m where genId :: m Int +instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +instance MonadNameGen IdGen.IdGen where genId = IdGen.genId +instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) -genName' :: String -> CompM String +genName' :: MonadNameGen m => String -> m String  genName' "" = genName  genName' prefix = (prefix ++) . show <$> genId -genName :: CompM String +genName :: MonadNameGen m => m String  genName = genName' "x" +onlyIdGen :: IdGen.IdGen a -> CompM a +onlyIdGen m = CompM $ do +  i1 <- gets csNextId +  let (res, i2) = IdGen.runIdGen' i1 m +  modify (\s -> s { csNextId = i2 }) +  return res +  emit :: Stmt -> CompM () -emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } +emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }  scope :: CompM a -> CompM (a, Bag Stmt)  scope m = do -  stmts <- state $ \s -> (csStmts s, s { csStmts = mempty }) +  stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })    res <- m -  innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts }) +  innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })    return (res, innerStmts)  emitStruct :: STy t -> CompM String -emitStruct ty = do -  let ty' = unSTy ty -  modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } -  return (genStructName ty') +emitStruct ty = CompM $ do +  modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) } +  return (genStructName ty)  emitTLD :: String -> CompM () -emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } +emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }  nameEnv :: SList f env -> SList (Const String) env  nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) @@ -313,7 +328,7 @@ nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg"  compileToString :: SList STy env -> Ex env t -> String  compileToString env expr =    let args = nameEnv env -      (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1) +      (res, s) = runCompM (compile' args expr)        structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))        (arg_pairs, arg_metrics) = @@ -649,7 +664,7 @@ compile' env = \case      x0name <- compileAssign "foldx0" env ex0      arrname <- compileAssign "foldarr" env earr -    zeroRefcountCheck "fold1i" arrname +    zeroRefcountCheck (typeOf earr) "fold1i" arrname      shszname <- genName' "shsz"      -- This n is one less than the shape of the thing we're querying, which is @@ -694,7 +709,7 @@ compile' env = \case      let STArr (SS n) t = typeOf e      argname <- compileAssign "sumarg" env e -    zeroRefcountCheck "sum1i" argname +    zeroRefcountCheck (typeOf e) "sum1i" argname      shszname <- genName' "shsz"      -- This n is one less than the shape of the thing we're querying, like EFold1Inner. @@ -737,7 +752,7 @@ compile' env = \case      lenname <- compileAssign "replen" env elen      argname <- compileAssign "reparg" env earg -    zeroRefcountCheck "replicate1i" argname +    zeroRefcountCheck (typeOf earg) "replicate1i" argname      shszname <- genName' "shsz"      emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -767,7 +782,7 @@ compile' env = \case    EIdx0 _ e -> do      let STArr _ t = typeOf e      arrname <- compileAssign "" env e -    zeroRefcountCheck "idx0" arrname +    zeroRefcountCheck (typeOf e) "idx0" arrname      name <- genName      emit $ SVarDecl True (repSTy t) name                (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) @@ -779,7 +794,7 @@ compile' env = \case    EIdx _ earr eidx -> do      let STArr n t = typeOf earr      arrname <- compileAssign "ixarr" env earr -    zeroRefcountCheck "idx" arrname +    zeroRefcountCheck (typeOf earr) "idx" arrname      idxname <- if fromSNat n > 0  -- prevent an unused-varable warning                   then compileAssign "ixix" env eidx                   else return ""  -- won't be used in this case @@ -803,7 +818,7 @@ compile' env = \case          t = tTup (sreplicate n tIx)      _ <- emitStruct t      name <- compileAssign "" env e -    zeroRefcountCheck "shape" name +    zeroRefcountCheck (typeOf e) "shape" name      resname <- genName      emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)      incrementVarAlways Decrement (typeOf e) name @@ -833,15 +848,15 @@ compile' env = \case      actyname <- emitStruct (STAccum t)      name1 <- compileAssign "" env e1 -    zeroRefcountCheck "with" name1 +    zeroRefcountCheck (typeOf e1) "with" name1 -    mcopy <- copyForWriting t name1 +    mcopy <- copyForWriting (CHAD.d2 t) name1      accname <- genName' "accum"      emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])      e2' <- compile' (Const accname `SCons` env) e2 -    rettyname <- emitStruct (STPair (typeOf e2) t) +    rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))      return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")]    EAccum _ t prj eidx eval eacc -> do @@ -1096,7 +1111,7 @@ shapeTupFromLitVars = \n -> go n . reverse  compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr  compileOpGeneral op e1 = do -  let unary cop = return @(State CompState) $ CECall cop [e1] +  let unary cop = return @CompM $ CECall cop [e1]    let binary cop = do          name <- genName          emit $ SVarDecl True (repSTy (opt1 op)) name e1 @@ -1127,7 +1142,7 @@ compileOpGeneral op e1 = do  compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr  compileOpPair op e1 e2 = do -  let binary cop = return @(State CompState) $ CEBinop e1 cop e2 +  let binary cop = return @CompM $ CEBinop e1 cop e2    case op of      OAdd _ -> binary "+"      OMul _ -> binary "*" @@ -1153,7 +1168,7 @@ compileExtremum nameBase opName operator env e = do    let STArr (SS n) t = typeOf e    argname <- compileAssign (nameBase ++ "arg") env e -  zeroRefcountCheck opName argname +  zeroRefcountCheck (typeOf e) opName argname    shszname <- genName' "shsz"    -- This n is one less than the shape of the thing we're querying, which is @@ -1209,7 +1224,7 @@ copyForWriting topty var = case topty of        _ -> do          name <- genName          emit $ SVarDeclUninit (repSTy topty) name -        emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) +        emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))                   (stmts1                    <> pure (SAsg name (CEStruct (repSTy topty)                                          [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) @@ -1225,7 +1240,7 @@ copyForWriting topty var = case topty of        Just e1' -> do          name <- genName          emit $ SVarDeclUninit (repSTy topty) name -        emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) +        emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))                   (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))                   (stmts1                    <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) @@ -1301,13 +1316,49 @@ copyForWriting topty var = case topty of    STAccum _ -> error "Compile: Nested accumulators not supported" -zeroRefcountCheck :: String -> String -> CompM () -zeroRefcountCheck opname arrvar = -  when emitChecks $ -    emit $ SVerbatim $ -      "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++ -      "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ -      "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }" +zeroRefcountCheck :: STy t -> String -> String -> CompM () +zeroRefcountCheck toptyp opname topvar = +  when emitChecks $ do +    mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) +    case mstmts of +      Nothing -> return () +      Just stmts -> forM_ stmts emit +  where +    -- | If this returns 'Nothing', no statements need to be generated for this type. +    go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) +    go STNil _ = empty +    go (STPair a b) path = do +      (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) +      return (s1 <> s2) +    go (STEither a b) path = do +      (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) +      return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 +    go (STMaybe a) path = do +      ss <- go a (path++".j") +      return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty +    go (STArr n a) path = do +      ivar <- genName' "i" +      ss <- go a (path++".buf->xs["++ivar++"]") +      shszname <- genName' "shsz" +      let s1 = SVerbatim $ +                 "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ +                 "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ +                 "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }" +      let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) +      let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss +      return (BList [s1, s2, s3]) +    go STScal{} _ = empty +    go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + +    combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) +    combine (MaybeT a) (MaybeT b) = MaybeT $ do +      x <- a +      y <- b +      return $ case (x, y) of +        (Nothing, Nothing) -> Nothing +        (Just x', Nothing) -> Just (x', mempty) +        (Nothing, Just y') -> Just (mempty, y') +        (Just x', Just y') -> Just (x', y')  compose :: Foldable t => t (a -> a) -> a -> a  compose = foldr (.) id diff --git a/src/Data.hs b/src/Data.hs index 60afdd0..e7b3148 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -11,6 +11,7 @@  module Data (module Data, (:~:)(Refl)) where  import Data.Functor.Product +import Data.Some  import Data.Type.Equality  import Unsafe.Coerce (unsafeCoerce) @@ -85,6 +86,10 @@ unSNat :: SNat n -> Nat  unSNat SZ = Z  unSNat (SS n) = S (unSNat n) +reSNat :: Nat -> Some SNat +reSNat Z = Some SZ +reSNat (S n) | Some n' <- reSNat n = Some (SS n') +  fromNat :: Nat -> Int  fromNat Z = 0  fromNat (S m) = succ (fromNat m) diff --git a/src/Example.hs b/src/Example.hs index e234ff4..2c710a1 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -183,7 +183,7 @@ neuralGo =            ELet ext (EConst ext STF64 1.0) $              chad defaultConfig knownEnv neural        (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of -        (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') +        (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')          _ -> undefined        (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0    in trace (ppExpr knownEnv revderiv) $ diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3cc7ae4..ddc3479 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -188,8 +188,7 @@ zeroD2 typ = case typ of    STPair _ _ -> Nothing    STEither _ _ -> Nothing    STMaybe _ -> Nothing -  STArr SZ t -> arrayUnit (zeroD2 t) -  STArr n _ -> emptyArray n +  STArr _ _ -> Nothing    STScal sty -> case sty of                    STI32 -> ()                    STI64 -> () @@ -215,13 +214,16 @@ addD2s typ a b = case typ of      (Nothing, _) -> b      (_, Nothing) -> a      (Just x, Just y) -> Just (addD2s t x y) -  STArr _ t -> -    let sh1 = arrayShape a -        sh2 = arrayShape b -    in if | shapeSize sh1 == 0 -> b -          | shapeSize sh2 == 0 -> a -          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i)) -          | otherwise -> error "Plus of inconsistently shaped arrays" +  STArr _ t -> case (a, b) of +    (Nothing, _) -> b +    (_, Nothing) -> a +    (Just x, Just y) -> +      let sh1 = arrayShape x +          sh2 = arrayShape y +      in if | shapeSize sh1 == 0 -> Just y +            | shapeSize sh2 == 0 -> Just x +            | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) +            | otherwise -> error "Plus of inconsistently shaped arrays"    STScal sty -> case sty of      STI32 -> ()      STI64 -> () @@ -238,7 +240,7 @@ onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx v  onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))  onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)  onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = -  runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx +  Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx  withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))  withAccum t _ initval f = AcM $ do @@ -253,7 +255,7 @@ newAcZero = \case    STPair{} -> newIORef Nothing    STEither{} -> newIORef Nothing    STMaybe _ -> newIORef Nothing -  STArr n _ -> newIORef (emptyArray n) +  STArr _ _ -> newIORef Nothing    STScal sty -> case sty of      STI32 -> return ()      STI64 -> return () @@ -268,7 +270,7 @@ newAcSparse typ prj idx val = case (typ, prj) of    (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val    (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val    (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val -  (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val +  (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val    (STScal sty, SAPHere) -> case sty of      STI32 -> return ()      STI64 -> return () @@ -286,7 +288,7 @@ newAcSparse typ prj idx val = case (typ, prj) of    (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val -  (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val +  (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val    (STAccum{}, _) -> error "Accumulators not allowed in source program" @@ -309,7 +311,7 @@ readAcSparse typ val = case typ of    STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val    STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val    STMaybe t -> traverse (readAcSparse t) =<< readIORef val -  STArr _ t -> traverse (readAcSparse t) =<< readIORef val +  STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val    STScal sty -> case sty of      STI32 -> return ()      STI64 -> return () @@ -360,32 +362,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of                               (\ac -> accumAddSparse t1 prj' ac idx val)    (STArr _ t1, SAPHere) -> -    let add ac = forM_ [0 .. arraySize ac - 1] $ \i -> -                   unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i) -    in if arraySize val == 0 -         then return () -         else AcM $ join $ atomicModifyIORef' ref $ \ac -> -                if arraySize ac == 0 -                  then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val -                               join $ atomicModifyIORef' ref $ \ac' -> -                                 if arraySize ac == 0 -                                   then (newac, return ()) -                                   else (ac', add ac')) -                  else (ac, add ac) +    case val of +      Nothing -> return () +      Just val' -> +        realiseMaybeSparse ref +          (arrayMapM (newAcSparse t1 SAPHere ()) val') +          (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> +                    accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))    (STArr n t1, SAPArrIdx prj' _) ->      let ((arrindex', arrsh'), idx') = idx          arrindex = unTupRepIdx IxNil IxCons n arrindex'          arrsh = unTupRepIdx ShNil ShCons n arrsh'          linindex = toLinearIndex arrsh arrindex -        add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val -    in AcM $ join $ atomicModifyIORef' ref $ \ac -> -         if arraySize ac == 0 -           then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx -                        join $ atomicModifyIORef' ref $ \ac' -> -                          if arraySize ac == 0 -                            then (newac, return ()) -                            else (ac', add ac')) -           else (ac, add ac) +    in realiseMaybeSparse ref +         (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) +         (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)    (STScal sty, SAPHere) -> AcM $ case sty of      STI32 -> return () diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index f84f4e7..be2a4cc 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -29,8 +29,7 @@ type family RepAc t where    RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b))    RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))    RepAc (TMaybe t) = IORef (Maybe (RepAc t)) -  -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero. -  RepAc (TArr n t) = IORef (Array n (RepAc t))  -- empty array is zero +  RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t)))    RepAc (TScal sty) = RepAcScal sty    RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") diff --git a/src/Util/IdGen.hs b/src/Util/IdGen.hs index fcfb6e7..3f6611d 100644 --- a/src/Util/IdGen.hs +++ b/src/Util/IdGen.hs @@ -2,11 +2,12 @@  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  module Util.IdGen where +import Control.Monad.Fix  import Control.Monad.Trans.State.Strict  newtype IdGen a = IdGen (State Int a) -  deriving newtype (Functor, Applicative, Monad) +  deriving newtype (Functor, Applicative, Monad, MonadFix)  genId :: IdGen Int  genId = IdGen (state (\i -> (i, i + 1))) | 
