diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 /src/AST | |
parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/AST')
-rw-r--r-- | src/AST/Types.hs | 51 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 16 |
2 files changed, 59 insertions, 8 deletions
diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 0b41671..217b2f5 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -8,7 +9,9 @@ module AST.Types where import Data.Int (Int32, Int64) +import Data.GADT.Show import Data.Kind (Type) +import Data.Some import Data.Type.Equality import Data @@ -54,6 +57,8 @@ instance TestEquality STy where testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl testEquality STAccum{} _ = Nothing +instance GShow STy where gshowsPrec = defaultGshowsPrec + data SScalTy t where STI32 :: SScalTy TI32 STI64 :: SScalTy TI64 @@ -70,6 +75,8 @@ instance TestEquality SScalTy where testEquality STBool STBool = Just Refl testEquality _ _ = Nothing +instance GShow SScalTy where gshowsPrec = defaultGshowsPrec + scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) scalRepIsShow STI32 = Dict scalRepIsShow STI64 = Dict @@ -82,6 +89,50 @@ type TIx = TScal TI64 tIx :: STy TIx tIx = STScal STI64 +unSTy :: STy t -> Ty +unSTy = \case + STNil -> TNil + STPair a b -> TPair (unSTy a) (unSTy b) + STEither a b -> TEither (unSTy a) (unSTy b) + STMaybe t -> TMaybe (unSTy t) + STArr n t -> TArr (unSNat n) (unSTy t) + STScal t -> TScal (unSScalTy t) + STAccum t -> TAccum (unSTy t) + +unSEnv :: SList STy env -> [Ty] +unSEnv SNil = [] +unSEnv (SCons t l) = unSTy t : unSEnv l + +unSScalTy :: SScalTy t -> ScalTy +unSScalTy = \case + STI32 -> TI32 + STI64 -> TI64 + STF32 -> TF32 + STF64 -> TF64 + STBool -> TBool + +reSTy :: Ty -> Some STy +reSTy = \case + TNil -> Some STNil + TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b' + TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b' + TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t' + TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t' + TScal t | Some t' <- reSScalTy t -> Some $ STScal t' + TAccum t | Some t' <- reSTy t -> Some $ STAccum t' + +reSEnv :: [Ty] -> Some (SList STy) +reSEnv [] = Some SNil +reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env) + +reSScalTy :: ScalTy -> Some SScalTy +reSScalTy = \case + TI32 -> Some STI32 + TI64 -> Some STI64 + TF32 -> Some STF32 + TF64 -> Some STF64 + TBool -> Some STBool + type family ScalRep t where ScalRep TI32 = Int32 ScalRep TI64 = Int64 diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index b30f7a0..0da1afc 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -51,8 +51,8 @@ zero STNil = ENil ext zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = EUnit ext (zero t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty") +zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) +zero (STArr n t) = ENothing ext (STArr n (d2 t)) zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -84,8 +84,7 @@ plus (STMaybe t) a b = plusSparse (d2 t) a b $ plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) plus (STArr n t) a b = - ELet ext a $ - ELet ext (weakenExpr WSink b) $ + plusSparse (STArr n (d2 t)) a b $ eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) (EVar ext (STArr n (d2 t)) IZ) (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) @@ -131,7 +130,8 @@ onehot typ topprj idx arg = case (typ, topprj) of (STArr n t1, SAPArrIdx prj _) -> let tidx = tTup (sreplicate n tIx) in ELet ext idx $ - EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (zero t1) + EJust ext $ + EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (zero t1) |