aboutsummaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
commit57779d4303f377004705c8da06a5ac46177950b2 (patch)
tree0407089403d3d5c2de778c1aab7aed8adf2d01c0 /src/Language/AST.hs
parent351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff)
drevLambda works, TODO D[map]HEADmaster
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r--src/Language/AST.hs41
1 files changed, 31 insertions, 10 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index c9d05c9..a3b8130 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -4,7 +4,9 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -17,7 +19,7 @@ module Language.AST where
import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
-import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..))
+import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)
import Array
import AST
@@ -50,7 +52,7 @@ data NExpr env t where
-- array operations
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
- NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
@@ -58,7 +60,7 @@ data NExpr env t where
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
- NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b)
+ NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
-> NExpr env t1
-> NExpr env (TArr (S n) t1)
-> NExpr env (TPair (TArr n t1) (TArr (S n) b))
@@ -96,11 +98,16 @@ data NExpr env t where
NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)
-type family Lookup name env where
- Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
- Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
- Lookup name ('(name, t) : env) = t
- Lookup name (_ : env) = Lookup name env
+type Lookup name env = Lookup1 (name == "_") name env
+type family Lookup1 eqblank name env where
+ Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
+ Lookup1 False name env = Lookup2 name env
+type family Lookup2 name env where
+ Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
+ Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env
+type family Lookup3 eq t name env where
+ Lookup3 True t _ _ = t
+ Lookup3 False _ name env = Lookup2 name env
type family DropNth i env where
DropNth Z (_ : env) = env
@@ -209,7 +216,7 @@ fromNamedExpr val = \case
NEConstArr n t x -> EConstArr ext n t x
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
@@ -217,7 +224,7 @@ fromNamedExpr val = \case
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
NEReshape n a b -> EReshape ext n (go a) (go b)
- NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
NEConst t x -> EConst ext t x
@@ -275,3 +282,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ (_ `NPush` _) = WSink
dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
dropNthW _ NTop = error "DropNth: index out of range"
+
+assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
+assertSymbolNotUnderscore s@SSymbol k =
+ case symbolVal s of
+ "_" -> error "assertSymbolNotUnderscore: was underscore"
+ _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k
+
+assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
+assertSymbolDistinct s1@SSymbol s2@SSymbol k
+ | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
+ | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k
+
+equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
+equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k