summaryrefslogtreecommitdiff
path: root/src/CHAD/Types.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-06 22:50:06 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-06 22:50:06 +0200
commit56056c98b2e3dce65a0e42bce0410c083fd1f8be (patch)
tree8db2d1be037f8f980c3d1bf76ff9078048f33d63 /src/CHAD/Types.hs
parent7bd37711ffecb7a0e202ecfd717e3a4cbbe6074f (diff)
WIP mixed static/dynamic sparsitysparse
Diffstat (limited to 'src/CHAD/Types.hs')
-rw-r--r--src/CHAD/Types.hs16
1 files changed, 12 insertions, 4 deletions
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 974669d..83f013d 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
+import AST.Accum
import AST.Types
import Data
@@ -18,11 +19,11 @@ type family D1 t where
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
D2 (TEither a b) = TLEither (D2 a) (D2 b)
D2 (TLEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TMaybe (TArr n (D2 t))
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
type family D2s t where
@@ -60,11 +61,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env
d2M :: STy t -> SMTy (D2 t)
d2M STNil = SMTNil
-d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STMaybe t) = SMTMaybe (d2M t)
-d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
+d2M (STArr n t) = SMTArr n (d2M t)
d2M (STScal t) = case t of
STI32 -> SMTNil
STI64 -> SMTNil
@@ -116,3 +117,10 @@ chcSetAccum c = c { chcLetArrayAccum = True
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id SZ = Refl
indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl