summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-16 23:21:55 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-16 23:21:55 +0200
commit2b1a40b5933b8b0dceaae744e5b70cb604822c9d (patch)
tree652d6d88efd2b0b4502819297333305cec5242c4 /src/AST
parenteed0f2999d6f6c8485ef53deb38f9d0a67b4f88e (diff)
CHAD.hs compiles
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Accum.hs36
-rw-r--r--src/AST/UnMonoid.hs2
2 files changed, 27 insertions, 11 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 1101cc0..158b4d9 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
@@ -32,21 +33,36 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type family AcIdx p t where
- AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
- AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
- AcIdx (APLeft p) (TLEither a b) = AcIdx p a
- AcIdx (APRight p) (TLEither a b) = AcIdx p b
- AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) =
+type data StillDense = AI_D | AI_S
+data SStillDense dense where
+ SAI_D :: SStillDense AI_D
+ SAI_S :: SStillDense AI_S
+deriving instance Show (SStillDense dense)
+
+type family AcIdx dense p t where
+ AcIdx dense APHere t = TNil
+ AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a
+ AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b
+ AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b)
+ AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b)
+ AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a
+ AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b
+ AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a
+ AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a)
+ AcIdx AI_S (APArrIdx p) (TArr n a) =
-- ((index, shapes info), recursive info)
TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx p a)
- -- AcIdx (APArrSlice m) (TArr n a) =
+ (AcIdx AI_S p a)
+ -- AcIdx AI_D (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AI_S (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+type AcIdxD p t = AcIdx AI_D p t
+type AcIdxS p t = AcIdx AI_S p t
+
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ac4d733..389dd5a 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -105,7 +105,7 @@ plus (SMTArr _ t) a b =
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
(_, SAPHere) ->
ELet ext arg $