diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-13 23:07:04 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-13 23:07:04 +0200 | 
| commit | 94938d648e021d2ace0f3b7bf383d256449d619f (patch) | |
| tree | ef077de27b08027c7117761c3efc7d29b7ad3d56 /src/Interpreter.hs | |
| parent | 3d8a6cca424fc5279c43a266900160feb28aa715 (diff) | |
WIP better zero/plus, fixing Accum (...)
The accumulator implementation was wrong because it forgot (in accumAdd)
to take into account that values may be variably-sized. Furthermore, it
was also complexity-inefficient because it did not build up a sparse
value. Thus let's go for the Haskell-interpreter-equivalent of what a
real, fast, compiled implementation would do: just a tree with mutable
variables. In practice one can decide to indeed flatten parts of that
tree, i.e. using a tree representation for nested pairs is bad, but that
should have been done _before_ execution and for _all_ occurrences of
that type fragment, not live at runtime by the accumulator
implementation.
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 142 | 
1 files changed, 86 insertions, 56 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 7ffb14c..8728ec0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,39 +1,44 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-}  module Interpreter (    interpret,    interpret',    Value, -  NoAccum(..), -  unAccum,  ) where +import Control.Monad (foldM)  import Data.Int (Int64)  import Data.Proxy +import System.IO.Unsafe (unsafePerformIO) +import Array  import AST +import CHAD.Types  import Data -import Array -import Interpreter.Accum  import Interpreter.Rep -import Control.Monad (foldM) -interpret :: NoAccum t => Ex '[] t -> Rep t -interpret e = runAcM (go e) -  where -    go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t) -    go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e' +newtype AcM s a = AcM (IO a) +  deriving newtype (Functor, Applicative, Monad) -newtype Value s t = Value (Rep' s t) +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m -interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t) +interpret :: Ex '[] t -> Rep t +interpret e = runAcM (interpret' SNil e) + +newtype Value t = Value (Rep t) + +interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)  interpret' env = \case    EVar _ _ i -> case slistIdx env i of Value x -> return x    ELet _ a b -> do @@ -48,14 +53,17 @@ interpret' env = \case    ECase _ e a b -> interpret' env e >>= \case                       Left x -> interpret' (Value x `SCons` env) a                       Right y -> interpret' (Value y `SCons` env) b +  ENothing _ _ -> _ +  EJust _ _ -> _ +  EMaybe _ _ _ _ -> _    EConstArr _ _ _ v -> return v    EBuild1 _ a b -> do      n <- fromIntegral @Int64 @Int <$> interpret' env a      arrayGenerateLinM (ShNil `ShCons` n)                        (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)    EBuild _ dim a b -> do -    sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a -    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b) +    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a +    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)    EFold1Inner _ a b -> do      let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a      arr <- interpret' env b @@ -75,9 +83,9 @@ interpret' env = \case    EConst _ _ v -> return v    EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e    EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) -  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b) -  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e -  EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e +  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) +  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e +  EOp _ op e -> interpretOp op <$> interpret' env e    EWith e1 e2 -> do      initval <- interpret' env e1      withAccum (typeOf e1) initval $ \accum -> @@ -87,10 +95,16 @@ interpret' env = \case      val <- interpret' env e2      accum <- interpret' env e3      accumAdd accum i idx val +  EZero t -> do +    return $ makeZero t +  EPlus t a b -> do +    a' <- interpret' env a +    b' <- interpret' env b +    return $ makePlus t a' b'    EError _ s -> error $ "Interpreter: Program threw error: " ++ s -interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t -interpretOp _ op arg = case op of +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of    OAdd st -> numericIsNum st $ uncurry (+) arg    OMul st -> numericIsNum st $ uncurry (*) arg    ONeg st -> numericIsNum st $ negate arg @@ -100,23 +114,66 @@ interpretOp _ op arg = case op of    ONot -> not arg    OIf -> if arg then Left () else Right () +makeZero :: STy t -> Rep (D2 t) +makeZero typ = case typ of +  STNil -> () +  STPair _ _ -> Left () +  STEither _ _ -> Left () +  STMaybe _ -> Nothing +  STArr n _ -> emptyArray n +  STScal sty -> case sty of +                  STI32 -> () +                  STI64 -> () +                  STF32 -> 0.0 +                  STF64 -> 0.0 +                  STBool -> () +  STAccum{} -> error "Zero of Accum" + +makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) +makePlus typ a b = case typ of +  STNil -> () +  STPair t1 t2 -> case (a, b) of +    (Left (), _) -> b +    (_, Left ()) -> a +    (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2) +  STEither t1 t2 -> case (a, b) of +    (Left (), _) -> b +    (_, Left ()) -> a +    (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y)) +    (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y)) +    _ -> error "Plus of inconsistent Eithers" +  STArr _ t -> +    let sh1 = arrayShape a +        sh2 = arrayShape b +    in if | shapeSize sh1 == 0 -> b +          | shapeSize sh2 == 0 -> a +          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i)) +          | otherwise -> error "Plus of inconsistently shaped arrays" +  STScal sty -> case sty of +    STI32 -> () +    STI64 -> () +    STF32 -> a + b +    STF64 -> a + b +    STBool -> () +  STAccum{} -> error "Plus of Accum" +  numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r  numericIsNum STI32 = id  numericIsNum STI64 = id  numericIsNum STF32 = id  numericIsNum STF64 = id -unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) -            -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n -unTupRepIdx _ nil _    SZ _ = nil -unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) +            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _    SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i -tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) -          -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx)) -tupRepIdx _ _      SZ _ = () -tupRepIdx p uncons (SS n) tup = +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) +          -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _      SZ _ = () +tupRepIdx uncons (SS n) tup =    let (tup', i) = uncons tup -  in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i) +  in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)  ixUncons :: Index (S n) -> (Index n, Int)  ixUncons (IxCons idx i) = (idx, i) @@ -124,33 +181,6 @@ ixUncons (IxCons idx i) = (idx, i)  shUncons :: Shape (S n) -> (Shape n, Int)  shUncons (ShCons idx i) = (idx, i) -class NoAccum t where -  noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t -instance NoAccum TNil where -  noAccum _ _ = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where -  noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where -  noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance NoAccum t => NoAccum (TArr n t) where -  noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl -instance NoAccum (TScal t) where -  noAccum _ _ = Refl - -unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t)) -unAccum _ STNil = Just Dict -unAccum p (STPair t1 t2) -  | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict -  | otherwise = Nothing -unAccum p (STEither t1 t2) -  | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict -  | otherwise = Nothing -unAccum p (STArr _ t) -  | Just Dict <- unAccum p t = Just Dict -  | otherwise = Nothing -unAccum _ STScal{} = Just Dict -unAccum _ STAccum{} = Nothing -  foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a  foldl1M _ [] = error "foldl1M: empty list"  foldl1M f (tophead : toptail) = foldM f tophead toptail | 
