diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/CHAD/Types/ToTan.hs | 5 | ||||
| -rw-r--r-- | src/Compile.hs | 205 | ||||
| -rw-r--r-- | src/Example/GMM.hs | 4 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 28 | ||||
| -rw-r--r-- | src/Interpreter.hs | 98 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 40 | ||||
| -rw-r--r-- | src/Simplify.hs | 2 | 
7 files changed, 237 insertions, 145 deletions
| diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index f843206..87c01cb 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -40,3 +40,8 @@ toTan typ primal der = case typ of    STScal sty -> case sty of      STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der    STAccum{} -> error "Accumulators not allowed in input program" +  STLEither t1 t2 -> case (primal, der) of +                       (_, Nothing) -> Nothing +                       (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) +                       (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) +                       _ -> error "Primal and cotangent disagree on LEither alternative" diff --git a/src/Compile.hs b/src/Compile.hs index 503c342..6ba3a39 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1001,16 +1001,6 @@ compile' env = \case      return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]    EAccum _ t prj eidx eval eacc -> do -    nameidx <- compileAssign "acidx" env eidx -    nameval <- compileAssign "acval" env eval - -    -- Generate the variable manually because this one has to be non-const. -    -- TODO: old code: -    --   eacc' <- compile' env eacc -    --   nameacc <- genName' "acac" -    --   emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' -    nameacc <- compileAssign "acac" env eacc -      let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a          -- full zero array with the given zero info (for the type SMTArr n t1).          initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () @@ -1041,77 +1031,66 @@ compile' env = \case          initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi          initZeroFromMemset SMTScal{} = Nothing -    let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) -        initZero :: SMTy a -> String -> String -> CompM () -        initZero SMTNil _ _ = return () -        initZero (SMTPair t1 t2) v vzi = do -          initZero t1 (v++".a") (vzi++".a") -          initZero t2 (v++".b") (vzi++".b") -        initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") -        initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") -        initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi -        initZero (SMTScal sty) v _ = case sty of +    let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) +        initZeroZI :: SMTy a -> String -> String -> CompM () +        initZeroZI SMTNil _ _ = return () +        initZeroZI (SMTPair t1 t2) v vzi = do +          initZeroZI t1 (v++".a") (vzi++".a") +          initZeroZI t2 (v++".b") (vzi++".b") +        initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") +        initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") +        initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi +        initZeroZI (SMTScal sty) v _ = case sty of            STI32 -> emit $ SAsg v (CELit "0")            STI64 -> emit $ SAsg v (CELit "0l")            STF32 -> emit $ SAsg v (CELit "0.0f")            STF64 -> emit $ SAsg v (CELit "0.0") -    let -- | Dereference an accumulation value. Sparse components encountered -        -- along the way are initialised before proceeding downwards. At the -        -- point where we have the projected accumulator position available, -        -- the handler will be invoked with a variable name pointing to the -        -- projected position. -        -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler) -        accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM () -        accumRef _ SAPHere v _ k = k v - -        accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k -        accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k - -        accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do -          ((), stmtsInit1) <- scope $ initZero ta (v++".l") i -          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) -                   (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty -          accumRef ta prj' (v++".l") i k -        accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do -          ((), stmtsInit2) <- scope $ initZero tb (v++".r") i -          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) -                   (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty -          accumRef tb prj' (v++".r") i k - -        accumRef (SMTMaybe tj) (SAPJust prj') v i k = do -          ((), stmtsInit1) <- scope $ initZero tj (v++".j") i -          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) -                   (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty -          accumRef tj prj' (v++".j") i k - -        accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do -          when emitChecks $ do -            let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" -            forM_ (zip3 [0::Int ..] -                        (indexTupleComponents n (i++".a.a")) -                        (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do -              let a .||. b = CEBinop a "||" b -              emit $ SIf (CEBinop ixcomp "<" (CELit "0") -                          .||. -                          CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) -                          .||. -                          CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) -                       (pure $ SVerbatim $ -                          "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ -                          "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ -                          v ++ ".buf" ++ -                          concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ -                          concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ -                          concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ -                          "); " ++ -                          "return false;") -                       mempty - -          accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k +    let -- Initialise an uninitialised accumulation value, potentially already +        -- with the addend, potentially to zero depending on the nature of the +        -- projection. +        -- 1. If the projection indexes only through dense monoids before +        --    reaching SAPHere, the thing cannot be initialised to zero with +        --    only an AcIdx; it would need to model a zero after the addend, +        --    which is stupid and redundant. In this case, we return Left: +        --    (accumulation value) (AcIdx value) (addend value). +        --    The addend is copied, not consumed. (We can't reliably _always_ +        --    consume it, so it's not worth trying to do it sometimes.) +        -- 2. Otherwise, a sparse monoid is found along the way, and we can +        --    initalise the dense prefix of the path to zero by setting the +        --    indexed-through sparse value to a sparse zero. Afterwards, the +        --    main recursion can proceed further. In this case, we return +        --    Right: (accumulation value) (AcIdx value) +        -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) +        initZeroChunk :: SMTy a -> SAcPrj p a b +                      -> Either (String -> String -> String -> CompM ())  -- dense initialisation with addend +                                (String -> String -> CompM ())  -- zero initialisation of sparse chunk +        initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of +          -- reached target before the first sparse constructor +          (t1           , SAPHere    ) -> Left $ \v _ addend -> do +                                            incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend +                                            emit $ SAsg v (CELit addend) +          -- sparse types +          (SMTLEither{} , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") +          (SMTMaybe{}   , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") +          -- dense types +          (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do +                                            f (v++".a") (i++".a") +                                            initZeroZI t2 (v++".b") (i++".b") +          (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do +                                            initZeroZI t1 (v++".a") (i++".a") +                                            f (v++".b") (i++".b") +          (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do +                                             initZeroArray n t1 v (i++".a.b") +                                             linidxvar <- genName' "li" +                                             emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) +                                             f (v++".buf->xs["++linidxvar++"]") (i++".b") +          where +            applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i +            applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i      let -- Add a value (s) into an existing accumulation value (d). If a sparse -        -- component of d is encountered, s is simply written there. +        -- component of d is encountered, s is copied there.          add :: SMTy a -> String -> String -> CompM ()          add SMTNil _ _ = return ()          add (SMTPair t1 t2) d s = do @@ -1165,16 +1144,88 @@ compile' env = \case                         mempty            shsizename <- genName' "acshsz" -          emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")) +          emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)            ivar <- genName' "i"            ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")            emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)                     stmts1          add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" +    let -- | Dereference an accumulation value and add a given value to that +        -- position. Sparse components encountered along the way are +        -- initialised before proceeding downwards. +        -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) +        accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () +        accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + +        accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend +        accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + +        accumRef (SMTLEither ta tb) prj0 v i addend = do +          let chunkres = case prj0 of SAPLeft  prj' -> initZeroChunk ta prj' +                                      SAPRight prj' -> initZeroChunk tb prj' +              subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") +              tagval = case prj0 of SAPLeft{}  -> "1" +                                    SAPRight{} -> "2" +          ((), stmtsAdd) <- scope $ case prj0 of SAPLeft  prj' -> accumRef ta prj' subv i addend +                                                 SAPRight prj' -> accumRef tb prj' subv i addend +          case chunkres of +            Left densef -> do +              ((), stmtsSet) <- scope $ densef subv i addend +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) +                       stmtsAdd  -- TODO: emit check for consistency of tags? +            Right sparsef -> do +              ((), stmtsInit) <- scope $ sparsef subv i +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty +              forM_ stmtsAdd emit + +        accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do +          case initZeroChunk tj prj' of +            Left densef -> do +              ((), stmtsSet1) <- scope $ densef (v++".j") i addend +              ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) +                       stmtsAdd1 +            Right sparsef -> do +              ((), stmtsInit1) <- scope $ sparsef (v++".j") i +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty +              accumRef tj prj' (v++".j") i addend + +        accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do +          when emitChecks $ do +            let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" +            forM_ (zip3 [0::Int ..] +                        (indexTupleComponents n (i++".a.a")) +                        (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do +              let a .||. b = CEBinop a "||" b +              emit $ SIf (CEBinop ixcomp "<" (CELit "0") +                          .||. +                          CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) +                          .||. +                          CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) +                       (pure $ SVerbatim $ +                          "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ +                          "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ +                          v ++ ".buf" ++ +                          concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ +                          concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ +                          concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ +                          "); " ++ +                          "return false;") +                       mempty + +          accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + +    nameidx <- compileAssign "acidx" env eidx +    nameval <- compileAssign "acval" env eval +    nameacc <- compileAssign "acac" env eacc +      emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" -    accumRef t prj (nameacc++".buf->ac") nameidx $ \dest -> -      add (acPrjTy prj t) dest nameval +    accumRef t prj (nameacc++".buf->ac") nameidx nameval      emit $ SVerbatim $ "// compile EAccum end"      incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index 12bbd98..206e534 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -31,10 +31,10 @@ import Language  --   <https://tomsmeding.com/f/master.pdf>  --  -- The 'wrong' argument, when set to True, changes the objective function to --- one with a bug that makes a certain `build` result unused. This triggers +-- one with a bug that makes a certain `build` result unused. This  -- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's  -- dense, even though it may be a zero (i.e. empty). The "unused" test in --- test/Main.hs tries to isolate this test, but the wrong version of +-- test/Main.hs tries to isolate this case, but the wrong version of  -- gmmObjective is here to check (after that bug is fixed) whether it really  -- fixes the original bug.  gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b7036dd..5756f96 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -29,6 +29,7 @@ type family Tan t where    Tan (TMaybe t) = TMaybe (Tan t)    Tan (TArr n t) = TArr n (Tan t)    Tan (TScal t) = TanS t +  Tan (TLEither a b) = TLEither (Tan a) (Tan b)  type family TanS t where    TanS TI32 = TNil @@ -54,6 +55,11 @@ tanty (STScal t) = case t of    STF64 -> STScal STF64    STBool -> STNil  tanty STAccum{} = error "Accumulators not allowed in input program" +tanty (STLEither a b) = STLEither (tanty a) (tanty b) + +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env  zeroTan :: STy t -> Rep t -> Rep (Tan t)  zeroTan STNil () = () @@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0  zeroTan (STScal STF64) _ = 0.0  zeroTan (STScal STBool) _ = ()  zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))  tanScalars :: STy t -> Rep (Tan t) -> [Double]  tanScalars STNil () = [] @@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x]  tanScalars (STScal STF64) x = [x]  tanScalars (STScal STBool) _ = []  tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y  tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]  tanEScalars SNil SNil = [] @@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of    STF64 -> d    STBool -> (d, ())  unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN (STLEither a b) d = case d of +  Nothing -> (Nothing, Nothing) +  Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) +  Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)  dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double  dotprodTan STNil _ _ = 0.0 @@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of    STF64 -> x * y    STBool -> 0.0  dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan (STLEither a b) x y = case (x, y) of +  (Nothing, _) -> 0.0  -- 0 * y = 0 +  (_, Nothing) -> 0.0  -- x * 0 = 0 +  (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' +  (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' +  _ -> error "dotprodTan: incompatible LEither alternatives"  -- -- Primal expression must be duplicable  -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -174,6 +196,7 @@ dnConst (STScal t) = case t of    STF64 -> (,0.0)    STBool -> id  dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))  -- | Given a function that computes the forward derivative for a particular  -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of    STF64 -> \f -> f (x, 1.0)    STBool -> \_ -> ()  dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots (STLEither t1 t2) e = +  case e of +    Nothing -> \_ -> Nothing +    Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) +    Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))  dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)  dnConstEnv SNil SNil = SNil diff --git a/src/Interpreter.hs b/src/Interpreter.hs index af11de8..d7916d8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -35,7 +35,6 @@ import Debug.Trace  import Array  import AST  import AST.Pretty -import CHAD.Types  import Data  import Interpreter.Rep @@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re  withAccum t _ initval f = AcM $ do    accum <- newAcDense t initval    out <- unAcM $ f accum -  val <- readAcSparse t accum +  val <- readAc t accum    return (out, val)  newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) @@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =        !linindex = toLinearIndex arrsh arrindex    in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: SMTy t -> RepAc t -> IO (Rep t) -readAcSparse typ val = case typ of +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of    SMTNil -> return () -  SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val -  SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val -  SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val -  SMTArr _ t -> traverse (readAcSparse t) val +  SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val +  SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val +  SMTMaybe t -> traverse (readAc t) =<< readIORef val +  SMTArr _ t -> traverse (readAc t) val    SMTScal _ -> readIORef val -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of -  (STNil, SAPHere) -> return () - -  (STPair t1 t2, SAPHere) -> -    case val of -      Nothing -> return () -      Just (val1, val2) -> -        realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 -                                    <*> newAcDense t2 val2) -                               (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 -                                                  accumAddSparse t2 SAPHere ac2 () val2) -  (STPair t1 t2, SAPFst prj') -> -    realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) -                           (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) -  (STPair t1 t2, SAPSnd prj') -> -    realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) -                           (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - -  (STEither{}, SAPHere) -> +accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () +accumAddDense typ ref val = case typ of +  SMTNil -> return () +  SMTPair t1 t2 -> do +    accumAddDense t1 (fst ref) (fst val) +    accumAddDense t2 (snd ref) (snd val) +  SMTLEither{} ->      case val of        Nothing -> return ()        Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1        Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 -  (STEither t1 _, SAPLeft prj') -> +  SMTMaybe{} -> +    case val of +      Nothing -> return () +      Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' +  SMTArr _ t1 -> +    forM_ [0 .. arraySize ref - 1] $ \i -> +      accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) +  SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of +  (_, SAPHere) -> accumAddDense typ ref val + +  (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val +  (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val + +  (SMTLEither t1 _, SAPLeft prj') ->      realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)                             (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val                                    Right{} -> error "Mismatched Either in accumAddSparse (r +l)") -  (STEither _ t2, SAPRight prj') -> +  (SMTLEither _ t2, SAPRight prj') ->      realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)                             (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val                                    Left{} -> error "Mismatched Either in accumAddSparse (l +r)") -  (STMaybe{}, SAPHere) -> -    case val of -      Nothing -> return () -      Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' -  (STMaybe t1, SAPJust prj') -> -      realiseMaybeSparse ref (newAcSparse t1 prj' idx val) -                             (\ac -> accumAddSparse t1 prj' ac idx val) +  (SMTMaybe t1, SAPJust prj') -> +    realiseMaybeSparse ref (newAcSparse t1 prj' idx val) +                           (\ac -> accumAddSparse t1 prj' ac idx val) -  (STArr _ t1, SAPHere) -> -    case val of -      Nothing -> return () -      Just val' -> -        realiseMaybeSparse ref -          (arrayMapM (newAcDense t1) val') -          (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> -                    accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) -  (STArr n t1, SAPArrIdx prj') -> -    let ((arrindex', arrsh'), idx') = idx +  (SMTArr n t1, SAPArrIdx prj') -> +    let ((arrindex', ziarr), idx') = idx          arrindex = unTupRepIdx IxNil IxCons n arrindex' -        arrsh = unTupRepIdx ShNil ShCons n arrsh' +        arrsh = arrayShape ziarr          linindex = toLinearIndex arrsh arrindex -    in realiseMaybeSparse ref -         (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) -         (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - -  (STScal sty, SAPHere) -> AcM $ case sty of -    STI32 -> return () -    STI64 -> return () -    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STBool -> return () +    in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val -  (STAccum{}, _) -> error "Accumulators not allowed in source program"  realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()  realiseMaybeSparse ref makeval modifyval = diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 1226b0c..070ba4c 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -1,11 +1,16 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE UndecidableInstances #-}  module Interpreter.Rep where +import Control.DeepSeq +import Data.Coerce (coerce)  import Data.List (intersperse, intercalate)  import Data.Foldable (toList)  import Data.IORef +import GHC.Exts (withDict)  import Array  import AST @@ -58,12 +63,12 @@ showValue d (STArr _ t) arr = showParen (d > 10) $    . showString " ["    . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))    . showString "]" -showValue _ (STScal sty) x = case sty of -  STF32 -> shows x -  STF64 -> shows x -  STI32 -> shows x -  STI64 -> shows x -  STBool -> shows x +showValue d (STScal sty) x = case sty of +  STF32 -> showsPrec d x +  STF64 -> showsPrec d x +  STI32 -> showsPrec d x +  STI64 -> showsPrec d x +  STBool -> showsPrec d x  showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"  showValue _ (STLEither _ _) Nothing = showString "LNil"  showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x @@ -75,3 +80,26 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"      showEntries :: SList STy env -> SList Value env -> [String]      showEntries SNil SNil = []      showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs + +rnfRep :: STy t -> Rep t -> () +rnfRep STNil () = () +rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y +rnfRep (STEither a _) (Left x) = rnfRep a x +rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STMaybe _) Nothing = () +rnfRep (STMaybe t) (Just x) = rnfRep t x +rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = +  withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) +rnfRep (STScal t) x = case t of +  STI32 -> rnf x +  STI64 -> rnf x +  STF32 -> rnf x +  STF64 -> rnf x +  STBool -> rnf x +rnfRep STAccum{} _ = error "Cannot rnf accumulators" +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y + +instance KnownTy t => NFData (Value t) where +  rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/Simplify.hs b/src/Simplify.hs index 228f265..2a1d3b6 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -10,7 +10,7 @@  {-# LANGUAGE TypeOperators #-}  module Simplify (    simplifyN, simplifyFix, -  SimplifyConfig(..), simplifyWith, simplifyFixWith, +  SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,  ) where  import Data.Function (fix) | 
