aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev/Types.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Drev/Types.hs')
-rw-r--r--src/CHAD/Drev/Types.hs8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs
index 367a974..e119de2 100644
--- a/src/CHAD/Drev/Types.hs
+++ b/src/CHAD/Drev/Types.hs
@@ -4,7 +4,8 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Drev.Types where
-import CHAD.AST.Accum
+import Data.Proxy
+
import CHAD.AST.Types
import CHAD.Data
@@ -17,6 +18,7 @@ type family D1 t where
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
+ D1 (TUser t) = TUser t
type family D2 t where
D2 TNil = TNil
@@ -26,6 +28,7 @@ type family D2 t where
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
+ D2 (TUser t) = TUser (UserD2 t)
type family D2s t where
D2s TI32 = TNil
@@ -55,6 +58,7 @@ d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1 (STUser t) = STUser t
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
@@ -74,6 +78,7 @@ d2M (STScal t) = case t of
STF64 -> SMTScal STF64
STBool -> SMTNil
d2M STAccum{} = error "Accumulators not allowed in input program"
+d2M (STUser _) = SMTUser Proxy
d2 :: STy t -> STy (D2 t)
d2 = fromSMTy . d2M
@@ -147,6 +152,7 @@ d1Identity = \case
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
+ STUser{} -> Refl
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl