aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev/Types.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Drev/Types.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Drev/Types.hs')
-rw-r--r--src/CHAD/Drev/Types.hs153
1 files changed, 153 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs
new file mode 100644
index 0000000..367a974
--- /dev/null
+++ b/src/CHAD/Drev/Types.hs
@@ -0,0 +1,153 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Drev.Types where
+
+import CHAD.AST.Accum
+import CHAD.AST.Types
+import CHAD.Data
+
+
+type family D1 t where
+ D1 TNil = TNil
+ D1 (TPair a b) = TPair (D1 a) (D1 b)
+ D1 (TEither a b) = TEither (D1 a) (D1 b)
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
+ D1 (TMaybe a) = TMaybe (D1 a)
+ D1 (TArr n t) = TArr n (D1 t)
+ D1 (TScal t) = TScal t
+
+type family D2 t where
+ D2 TNil = TNil
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
+ D2 (TEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
+ D2 (TMaybe t) = TMaybe (D2 t)
+ D2 (TArr n t) = TArr n (D2 t)
+ D2 (TScal t) = D2s t
+
+type family D2s t where
+ D2s TI32 = TNil
+ D2s TI64 = TNil
+ D2s TF32 = TScal TF32
+ D2s TF64 = TScal TF64
+ D2s TBool = TNil
+
+type family D1E env where
+ D1E '[] = '[]
+ D1E (t : env) = D1 t : D1E env
+
+type family D2E env where
+ D2E '[] = '[]
+ D2E (t : env) = D2 t : D2E env
+
+type family D2AcE env where
+ D2AcE '[] = '[]
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
+
+d1 :: STy t -> STy (D1 t)
+d1 STNil = STNil
+d1 (STPair a b) = STPair (d1 a) (d1 b)
+d1 (STEither a b) = STEither (d1 a) (d1 b)
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
+d1 (STMaybe t) = STMaybe (d1 t)
+d1 (STArr n t) = STArr n (d1 t)
+d1 (STScal t) = STScal t
+d1 STAccum{} = error "Accumulators not allowed in input program"
+
+d1e :: SList STy env -> SList STy (D1E env)
+d1e SNil = SNil
+d1e (t `SCons` env) = d1 t `SCons` d1e env
+
+d2M :: STy t -> SMTy (D2 t)
+d2M STNil = SMTNil
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
+d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STMaybe t) = SMTMaybe (d2M t)
+d2M (STArr n t) = SMTArr n (d2M t)
+d2M (STScal t) = case t of
+ STI32 -> SMTNil
+ STI64 -> SMTNil
+ STF32 -> SMTScal STF32
+ STF64 -> SMTScal STF64
+ STBool -> SMTNil
+d2M STAccum{} = error "Accumulators not allowed in input program"
+
+d2 :: STy t -> STy (D2 t)
+d2 = fromSMTy . d2M
+
+d2eM :: SList STy env -> SList SMTy (D2E env)
+d2eM SNil = SNil
+d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts
+
+d2e :: SList STy env -> SList STy (D2E env)
+d2e = slistMap fromSMTy . d2eM
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts
+
+
+data CHADConfig = CHADConfig
+ { -- | D[let] will bind variables containing arrays in accumulator mode.
+ chcLetArrayAccum :: Bool
+ , -- | D[case] will bind variables containing arrays in accumulator mode.
+ chcCaseArrayAccum :: Bool
+ , -- | Introduce top-level arguments containing arrays in accumulator mode.
+ chcArgArrayAccum :: Bool
+ , -- | Place with-blocks around array variable scopes, and redirect accumulations there.
+ chcSmartWith :: Bool
+ }
+ deriving (Show)
+
+defaultConfig :: CHADConfig
+defaultConfig = CHADConfig
+ { chcLetArrayAccum = False
+ , chcCaseArrayAccum = False
+ , chcArgArrayAccum = False
+ , chcSmartWith = False
+ }
+
+chcSetAccum :: CHADConfig -> CHADConfig
+chcSetAccum c = c { chcLetArrayAccum = True
+ , chcCaseArrayAccum = True
+ , chcArgArrayAccum = True
+ , chcSmartWith = True }
+
+
+------------------------------------ LEMMAS ------------------------------------
+
+indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
+indexTupD1Id SZ = Refl
+indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl
+
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl