{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- | Reference interpreter for the 'Ex' AST: evaluates an expression to its
-- runtime representation ('Rep').
module Interpreter (
  interpret,
  interpret',
  Value,
) where

import Control.Monad (foldM)
import Data.Int (Int64)
import Data.Proxy
import System.IO.Unsafe (unsafePerformIO)

import Array
import AST
import CHAD.Types
import Data
import Interpreter.Rep


-- | The monad interpretation runs in: a thin newtype over 'IO' with a
-- phantom @s@ parameter that 'runAcM' quantifies over, in the style of
-- 'Control.Monad.ST.runST'.
newtype AcM s a = AcM (IO a)
  deriving newtype (Functor, Applicative, Monad)

-- | Run an 'AcM' computation as a pure value.
--
-- NOTE(review): this goes through 'unsafePerformIO', which is only sound if
-- the IO effects inside (presumably the accumulator updates behind 'EWith' /
-- 'EAccum') cannot be observed outside the computation. The rank-2 @s@ is
-- presumably meant to guarantee that, as with 'runST'; confirm.
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m

-- | Evaluate a closed expression.
interpret :: Ex '[] t -> Rep t
interpret e = runAcM (interpret' SNil e)

-- | A runtime value stored in the interpreter's environment.
newtype Value t = Value (Rep t)

-- | Evaluate an expression under an environment with one 'Value' per variable
-- in scope. Newly bound variables are pushed onto the front of the
-- environment with 'SCons'; 'EVar' looks them up with 'slistIdx'.
interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  ECase _ e a b -> interpret' env e >>= \case
    Left x -> interpret' (Value x `SCons` env) a
    Right y -> interpret' (Value y `SCons` env) b
  ENothing _ _ -> return Nothing
  EJust _ e -> Just <$> interpret' env e
  -- Fields: nothing-branch, just-branch (payload bound at index 0), scrutinee.
  EMaybe _ a b e -> interpret' env e >>= \case
    Nothing -> interpret' env a
    Just x -> interpret' (Value x `SCons` env) b
  EConstArr _ _ _ v -> return v
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    -- The loop index is exposed to the body as an Int64.
    arrayGenerateLinM (ShNil `ShCons` n)
      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  EBuild _ dim a b -> do
    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh
      (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
  EFold1Inner _ a b -> do
    -- The combining function sees the accumulator at index 1 and the next
    -- element at index 0.
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    -- 'case' rather than a lazy 'let' pattern: GHC restricts pattern
    -- bindings on GADT constructors such as 'ShCons'.
    case arrayShape arr of
      sh `ShCons` n ->
        arrayGenerateM sh $ \idx ->
          foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  ESum1Inner _ e -> do
    arr <- interpret' env e
    case typeOf e of
      STArr _ (STScal t) -> case arrayShape arr of
        sh `ShCons` n ->
          numericIsNum t $ arrayGenerateM sh $ \idx ->
            return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b ->
    arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  EIdx _ n a b ->
    arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
  EShape _ e -> case typeOf e of
    STArr n _ -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp op <$> interpret' env e
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  EAccum i e1 e2 e3 -> do
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAdd accum i idx val
  EZero t -> return (makeZero t)
  EPlus t a b -> makePlus t <$> interpret' env a <*> interpret' env b
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Apply a primitive operator to its (already evaluated) argument. Binary
-- operators receive their operands as a pair.
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp op arg = case op of
  OAdd st -> numericIsNum st $ uncurry (+) arg
  OMul st -> numericIsNum st $ uncurry (*) arg
  ONeg st -> numericIsNum st $ negate arg
  OLt st -> numericIsNum st $ uncurry (<) arg
  OLe st -> numericIsNum st $ uncurry (<=) arg
  OEq st -> numericIsNum st $ uncurry (==) arg
  ONot -> not arg
  OIf -> if arg then Left () else Right ()

-- | The zero of @Rep (D2 t)@. For pairs and eithers, @Left ()@ is the
-- structural zero; 'makePlus' treats it as the identity.
makeZero :: STy t -> Rep (D2 t)
makeZero typ = case typ of
  STNil -> ()
  STPair _ _ -> Left ()
  STEither _ _ -> Left ()
  STMaybe _ -> Nothing
  STArr n _ -> emptyArray n
  STScal sty -> case sty of
    STI32 -> ()
    STI64 -> ()
    STF32 -> 0.0
    STF64 -> 0.0
    STBool -> ()
  STAccum{} -> error "Zero of Accum"

-- | Add two values of @Rep (D2 t)@.
--
-- The zeros produced by 'makeZero' act as the identity. These are @Left ()@
-- for pairs and eithers, 'Nothing' for maybes, and any array of total size 0.
-- Adding a @Left@ either-component to a @Right@ one, or adding arrays of
-- different non-empty shapes, is an error.
makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
makePlus typ a b = case typ of
  STNil -> ()
  STPair t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
  STEither t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
    (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
    _ -> error "Plus of inconsistent Eithers"
  -- Previously missing (non-exhaustive match). 'Nothing' is the zero,
  -- matching 'makeZero'.
  STMaybe t -> case (a, b) of
    (Nothing, _) -> b
    (_, Nothing) -> a
    (Just x, Just y) -> Just (makePlus t x y)
  STArr _ t ->
    let sh1 = arrayShape a
        sh2 = arrayShape b
    in if | shapeSize sh1 == 0 -> b
          | shapeSize sh2 == 0 -> a
          | sh1 == sh2 ->
              arrayGenerateLin sh1
                (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
          | otherwise -> error "Plus of inconsistently shaped arrays"
  STScal sty -> case sty of
    STI32 -> ()
    STI64 -> ()
    STF32 -> a + b
    STF64 -> a + b
    STBool -> ()
  STAccum{} -> error "Plus of Accum"

-- | Bring 'Num' and 'Ord' instances for the representation of a numeric
-- scalar type into scope.
numericIsNum :: ScalIsNumeric st ~ 'True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id
numericIsNum STI64 = id
numericIsNum STF32 = id
numericIsNum STF64 = id

-- | Convert an index tuple (nested pairs of Int64, innermost component last)
-- into an @f n@, built with the supplied nil and cons.
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
unTupRepIdx nil cons (SS n) (idx, i) =
  -- Fixed: the recursive call previously passed a stray, unbound @p@.
  unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i

-- | Inverse of 'unTupRepIdx': turn an @f n@ into an index tuple of Int64,
-- using the supplied uncons.
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS n) tup =
  let (tup', i) = uncons tup
  in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)

-- | Split off the last component of a non-empty 'Index'.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (IxCons idx i) = (idx, i)

-- | Split off the last component of a non-empty 'Shape'.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)

-- | Monadic left fold seeded with the first list element. Calls 'error' on an
-- empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M _ [] = error "foldl1M: empty list"
foldl1M f (tophead : toptail) = foldM f tophead toptail