diff options
Diffstat (limited to 'src/CHAD/Drev/Types.hs')
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs index 367a974..e119de2 100644 --- a/src/CHAD/Drev/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -4,7 +4,8 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Drev.Types where -import CHAD.AST.Accum +import Data.Proxy + import CHAD.AST.Types import CHAD.Data @@ -17,6 +18,7 @@ type family D1 t where D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t + D1 (TUser t) = TUser t type family D2 t where D2 TNil = TNil @@ -26,6 +28,7 @@ type family D2 t where D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t + D2 (TUser t) = TUser (UserD2 t) type family D2s t where D2s TI32 = TNil @@ -55,6 +58,7 @@ d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1 (STUser t) = STUser t d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil @@ -74,6 +78,7 @@ d2M (STScal t) = case t of STF64 -> SMTScal STF64 STBool -> SMTNil d2M STAccum{} = error "Accumulators not allowed in input program" +d2M (STUser _) = SMTUser Proxy d2 :: STy t -> STy (D2 t) d2 = fromSMTy . d2M @@ -147,6 +152,7 @@ d1Identity = \case STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" + STUser{} -> Refl d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl |
