aboutsummaryrefslogtreecommitdiff
path: root/src/Language.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
commit57779d4303f377004705c8da06a5ac46177950b2 (patch)
tree0407089403d3d5c2de778c1aab7aed8adf2d01c0 /src/Language.hs
parent351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff)
drevLambda works, TODO D[map]HEADmaster
Diffstat (limited to 'src/Language.hs')
-rw-r--r--src/Language.hs32
1 files changed, 30 insertions, 2 deletions
diff --git a/src/Language.hs b/src/Language.hs
index 31b4b87..c1a6248 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
@@ -15,6 +16,8 @@ module Language (
Lookup,
) where
+import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
+
import Array
import AST
import AST.Sparse.Types
@@ -113,7 +116,19 @@ map_ (v :-> a) b
NEDrop (SS SZ) (NEDrop (SS SZ) a)
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
-fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3
+fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t t)
+ in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3
sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
sum1i e = NESum1Inner e
@@ -135,7 +150,20 @@ reshape = NEReshape
fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
-> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
-fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3
+fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t1 t1)
+ in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3
fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
-> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))