From 57779d4303f377004705c8da06a5ac46177950b2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 4 Nov 2025 23:09:21 +0100 Subject: drevLambda works, TODO D[map] --- src/Language/AST.hs | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) (limited to 'src/Language/AST.hs') diff --git a/src/Language/AST.hs b/src/Language/AST.hs index c9d05c9..a3b8130 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -17,7 +19,7 @@ module Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) import Array import AST @@ -50,7 +52,7 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) @@ -58,7 +60,7 @@ data NExpr env t where NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) - NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b) + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) @@ -96,11 +98,16 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +type Lookup name env = Lookup1 (name == "_") name env +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -209,7 +216,7 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) @@ -217,7 +224,7 @@ fromNamedExpr val = \case NEMinimum1Inner e -> EMinimum1Inner ext (go e) NEReshape n a b -> EReshape ext n (go a) (go b) - NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x @@ -275,3 +282,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k -- cgit v1.2.3-70-g09d2