summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs25
1 files changed, 17 insertions, 8 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 58d79a5..f8e7e98 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -99,6 +99,15 @@ interpret'Rec env = \case
EMaybe _ a b e ->
let STMaybe t1 = typeOf e
in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
+ ELNil _ _ _ -> return Nothing
+ ELInl _ _ e -> Just . Left <$> interpret' env e
+ ELInr _ _ e -> Just . Right <$> interpret' env e
+ ELCase _ e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Nothing -> interpret' env a
+ Just (Left x) -> interpret' (V t1 x `SCons` env) b
+ Just (Right y) -> interpret' (V t2 y `SCons` env) c
EConstArr _ _ _ v -> return v
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
@@ -136,9 +145,9 @@ interpret'Rec env = \case
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
- EIdx _ a b
- | STArr n _ <- typeOf a
- -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EIdx _ a b ->
+ let STArr n _ = typeOf a
+ in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
@@ -154,8 +163,8 @@ interpret'Rec env = \case
val <- interpret' env e2
accum <- interpret' env e3
accumAddSparse t p accum idx val
- EZero _ t -> do
- return $ zeroD2 t
+ EZero _ t ezi -> do
+ return $ zeroD2 t ezi
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
@@ -250,7 +259,7 @@ onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx
onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val))
onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
-onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
+onehotD2 (SAPArrIdx prj) (STArr n a) idx val =
Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
@@ -299,7 +308,7 @@ newAcSparse typ prj idx val = case (typ, prj) of
(STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
- (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val
+ (STArr n t, SAPArrIdx prj') -> newIORef . Just =<< newAcArray n t prj' idx val
(STAccum{}, _) -> error "Accumulators not allowed in source program"
@@ -380,7 +389,7 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
(arrayMapM (newAcSparse t1 SAPHere ()) val')
(\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
- (STArr n t1, SAPArrIdx prj' _) ->
+ (STArr n t1, SAPArrIdx prj') ->
let ((arrindex', arrsh'), idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = unTupRepIdx ShNil ShCons n arrsh'