{-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Types where import AST.Types import Data type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) D1 (TEither a b) = TEither (D1 a) (D1 b) D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t type family D2 t where D2 TNil = TNil D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where D2s TI32 = TNil D2s TI64 = TNil D2s TF32 = TScal TF32 D2s TF64 = TScal TF64 D2s TBool = TNil type family D1E env where D1E '[] = '[] D1E (t : env) = D1 t : D1E env type family D2E env where D2E '[] = '[] D2E (t : env) = D2 t : D2E env type family D2AcE env where D2AcE '[] = '[] D2AcE (t : env) = TAccum t : D2AcE env d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil d1e (t `SCons` env) = d1 t `SCons` d1e env d2 :: STy t -> STy (D2 t) d2 STNil = STNil d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) d2 (STMaybe t) = STMaybe (d2 t) d2 (STArr n t) = STArr n (d2 t) d2 (STScal t) = case t of STI32 -> STNil STI64 -> STNil STF32 -> STScal STF32 STF64 -> STScal STF64 STBool -> STNil d2 STAccum{} = error "Accumulators not allowed in input program" d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil d2e (t `SCons` ts) = d2 t `SCons` d2e ts d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts data CHADConfig = CHADConfig { -- | D[let] will bind variables containing arrays in accumulator mode. chcLetArrayAccum :: Bool , -- | D[case] will bind variables containing arrays in accumulator mode. chcCaseArrayAccum :: Bool , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool } deriving (Show) defaultConfig :: CHADConfig defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False , chcArgArrayAccum = False } chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True , chcArgArrayAccum = True } ------------------------------------ LEMMAS ------------------------------------ indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl