diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-04 23:09:21 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-04 23:09:21 +0100 |
| commit | 57779d4303f377004705c8da06a5ac46177950b2 (patch) | |
| tree | 0407089403d3d5c2de778c1aab7aed8adf2d01c0 /src/Language.hs | |
| parent | 351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff) | |
Diffstat (limited to 'src/Language.hs')
| -rw-r--r-- | src/Language.hs | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/src/Language.hs b/src/Language.hs index 31b4b87..c1a6248 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -15,6 +16,8 @@ module Language ( Lookup, ) where +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + import Array import AST import AST.Sparse.Types @@ -113,7 +116,19 @@ map_ (v :-> a) b NEDrop (SS SZ) (NEDrop (SS SZ) a) fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) sum1i e = NESum1Inner e @@ -135,7 +150,20 @@ reshape = NEReshape fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) -fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3 +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) |
