summaryrefslogtreecommitdiff
path: root/src/Language.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-09 22:58:52 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-09 22:58:52 +0100
commit34887168c0e2deb549e0e7c77e837ab269d894a2 (patch)
tree8e99e95fbd88569a79a634ed4cc9c787b10d21f1 /src/Language.hs
parent2b1562d33bb9496aa449ef9d52735af0ec61c15c (diff)
Add Custom to Language
Diffstat (limited to 'src/Language.hs')
-rw-r--r--src/Language.hs10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/Language.hs b/src/Language.hs
index 88cb1de..aa55140 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Language (
@@ -13,6 +14,7 @@ module Language (
import Array
import AST
+import CHAD.Types
import Data
import Language.AST
@@ -126,6 +128,14 @@ oper = NEOp
error_ :: KnownTy t => String -> NExpr env t
error_ s = NEError knownTy s
+custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
+ -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
+ -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
+ NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .== b = oper (OEq knownScalTy) (pair a b)
infix 4 .==