aboutsummaryrefslogtreecommitdiff
path: root/src/Language
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language')
-rw-r--r--src/Language/AST.hs49
1 files changed, 41 insertions, 8 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 325817d..a3b8130 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -4,7 +4,9 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -17,7 +19,7 @@ module Language.AST where
import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
-import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..))
+import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)
import Array
import AST
@@ -50,7 +52,7 @@ data NExpr env t where
-- array operations
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
- NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
@@ -58,6 +60,15 @@ data NExpr env t where
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+ NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
+ -> NExpr env t1
+ -> NExpr env (TArr (S n) t1)
+ -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+ NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
+ -> NExpr env (TArr (S n) b)
+ -> NExpr env (TArr n t2)
+ -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
@@ -87,11 +98,16 @@ data NExpr env t where
NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)
-type family Lookup name env where
- Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
- Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
- Lookup name ('(name, t) : env) = t
- Lookup name (_ : env) = Lookup name env
+type Lookup name env = Lookup1 (name == "_") name env
+type family Lookup1 eqblank name env where
+ Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
+ Lookup1 False name env = Lookup2 name env
+type family Lookup2 name env where
+ Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
+ Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env
+type family Lookup3 eq t name env where
+ Lookup3 True t _ _ = t
+ Lookup3 False _ name env = Lookup2 name env
type family DropNth i env where
DropNth Z (_ : env) = env
@@ -200,7 +216,7 @@ fromNamedExpr val = \case
NEConstArr n t x -> EConstArr ext n t x
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
@@ -208,6 +224,9 @@ fromNamedExpr val = \case
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
NEReshape n a b -> EReshape ext n (go a) (go b)
+ NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
+ NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
NEIdx1 a b -> EIdx1 ext (go a) (go b)
@@ -263,3 +282,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ (_ `NPush` _) = WSink
dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
dropNthW _ NTop = error "DropNth: index out of range"
+
+assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
+assertSymbolNotUnderscore s@SSymbol k =
+ case symbolVal s of
+ "_" -> error "assertSymbolNotUnderscore: was underscore"
+ _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k
+
+assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
+assertSymbolDistinct s1@SSymbol s2@SSymbol k
+ | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
+ | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k
+
+equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
+equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k