{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Language (
  fromNamed,
  NExpr,
  module Language,
  Lookup,
) where

import Array
import AST
import Data
import Language.AST


-- | Binder arrow used by the combinators in this module to attach one or more
-- variables to the expression in which they are bound, e.g. @v :-> body@ or
-- @v1 :-> v2 :-> body@.
data a :-> b = a :-> b
  deriving Show
-- Right-associative at the lowest precedence, so @v1 :-> v2 :-> e@ groups as
-- @v1 :-> (v2 :-> e)@, which is the shape 'build2' and 'fold1i' match on.
infixr 0 :->


-- | Wrap an expression as a function body that takes no further parameters.
-- The input and output environments of the resulting 'NFun' coincide.
body :: NExpr env t -> NFun env env t
body expr = NBody expr

-- | Abstract over one named parameter. The variable is pushed onto the front
-- of the environment seen by the rest of the function.
lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda var rest = NLam var rest

-- | Inline a closed function (its outer environment is @'[]@) at a use site,
-- supplying one argument expression per parameter. @UnName@ presumably
-- strips the parameter names from @params@; verify in the defining module.
inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
inline fun args = inlineNFun fun args

-- | Snoc-style builder for argument lists: @l .$ x@ is @SCons x l@.
--
-- The operator is declared left-associative, so @l .$ x .$ y@ is
-- @SCons y (SCons x l)@. The argument written last ends up at the head of
-- the list. (The fixity is the Haskell default, stated explicitly so the
-- nesting does not depend on an implicit rule.)
(.$) :: SList f list -> f a -> SList f (a : list)
xs .$ x = SCons x xs
infixl 9 .$


-- | Bind @var@ to the value of @rhs@ within @scope@. The scope sees @var@ at
-- the front of its environment.
let_ :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
let_ var rhs scope = NELet var rhs scope

-- | Build a 'TPair' from two expressions in the same environment.
pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
pair x y = NEPair x y

-- | Project the first component of a 'TPair'.
fst_ :: NExpr env (TPair a b) -> NExpr env a
fst_ p = NEFst p

-- | Project the second component of a 'TPair'.
snd_ :: NExpr env (TPair a b) -> NExpr env b
snd_ p = NESnd p

-- | The value of type 'TNil'. It is used, for example, as the innermost
-- element of the nested-pair shapes built by 'build1' and 'build2'.
nil :: NExpr env TNil
nil = NENil

-- | Inject into the left alternative of a 'TEither'. The type of the unused
-- right alternative must be supplied explicitly as an 'STy'.
inl :: STy b -> NExpr env a -> NExpr env (TEither a b)
inl rightTy e = NEInl rightTy e

-- | Inject into the right alternative of a 'TEither'. The type of the unused
-- left alternative must be supplied explicitly as an 'STy'.
inr :: STy a -> NExpr env b -> NExpr env (TEither a b)
inr leftTy e = NEInr leftTy e

-- | Eliminate a 'TEither'. Each alternative is written @var :-> expr@; @var@
-- is bound to the payload and is visible only inside its own @expr@.
case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
case_ scrut (leftVar :-> onLeft) (rightVar :-> onRight) =
  NECase scrut
         leftVar onLeft
         rightVar onRight

-- | Embed a host-side constant array. Both the rank and the scalar type are
-- taken from their @Known*@ instances. The 'Dict' match on 'scalRepIsShow'
-- brings whatever constraint that dictionary carries into scope for
-- 'NEConstArr'. (Presumably @Show (ScalRep t)@, judging by the name; confirm
-- in the defining module.)
constArr_ :: (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
constArr_ arr =
  case scalRepIsShow sty of
    Dict -> NEConstArr knownNat sty arr
  where
    -- One shared binding, so both uses refer to the same scalar type.
    sty = knownScalTy

-- | One-dimensional build: @build1 len (i :-> e)@ is an array of length @len@
-- whose element at index @i@ is @e@.
--
-- Internally the build binds the whole index tuple as @idx@ (shape
-- @(nil, len)@). It then rebinds its second component to the user's
-- variable. @NEDrop (SS SZ)@ weakens the user's body past the internal @idx@
-- binding, so that binding is not visible to it.
build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
build1 len (ixVar :-> ixBody) =
  NEBuild (SS SZ) (pair nil len) #idx $
    let_ ixVar (snd_ #idx) $
      NEDrop (SS SZ) ixBody

-- | Two-dimensional build: @build2 len1 len2 (i :-> j :-> e)@ is an array of
-- shape @(len1, len2)@ whose element at @(i, j)@ is @e@.
--
-- The shape is passed as the nested pair @((nil, len1), len2)@. The internal
-- index tuple @idx@ is destructured the same way: @i@ comes from
-- @snd (fst idx)@ and @j@ from @snd idx@.
build2 :: NExpr env TIx -> NExpr env TIx
       -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
       -> NExpr env (TArr (S (S Z)) t)
build2 len1 len2 (var1 :-> var2 :-> inner) =
  NEBuild (SS (SS SZ)) (pair (pair nil len1) len2) #idx
    (let_ var1 (snd_ (fst_ #idx))
      -- Written under NEDrop SZ so that @#idx@ here refers to the build's own
      -- index binding rather than to var1 (presumably guards against var1
      -- itself being named @idx@).
      (let_ var2 (NEDrop SZ (snd_ #idx))
        -- Hide the internal @idx@ (position 2) from the user's body.
        (NEDrop (SS (SS SZ)) inner)))

-- | Rank-polymorphic build. The shape is an @n@-tuple of 'TIx', and the body
-- sees the whole index tuple bound to a single variable.
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build rank shapeE (ixVar :-> ixBody) = NEBuild rank shapeE ixVar ixBody

-- | Fold away one dimension of an array, reducing rank @n+1@ to rank @n@.
-- The combining function binds two variables of the element type. The second
-- argument is presumably the initial value, and the dimension is presumably
-- the innermost one (per the 'NEFold1Inner' name); confirm in the AST
-- definition.
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
fold1i (x :-> y :-> combine) start arr = NEFold1Inner x y combine start arr

-- | Sum away one dimension of a numeric array, reducing rank @n+1@ to rank @n@.
sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
sum1i = NESum1Inner

-- | Wrap a single value as a rank-0 array.
unit :: NExpr env t -> NExpr env (TArr Z t)
unit e = NEUnit e

-- | Add one dimension of the given length to a numeric array, raising rank
-- @n@ to rank @n+1@.
replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
replicate1i = NEReplicate1Inner

-- | Embed a host-side scalar constant. The scalar type comes from the
-- 'KnownScalTy' instance. The 'Dict' match on 'scalRepIsShow' brings that
-- dictionary's constraint into scope for 'NEConst'.
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ val =
  case scalRepIsShow sty of
    Dict -> NEConst sty val
  where
    -- One shared binding, so both uses refer to the same scalar type.
    sty = knownScalTy

-- | Read the single element of a rank-0 array.
idx0 :: NExpr env (TArr Z t) -> NExpr env t
idx0 arr = NEIdx0 arr

-- | Index a rank-@n+1@ array with a single 'TIx', yielding a rank-@n@ array.
(.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
arr .! i = NEIdx1 arr i
infixl 9 .!

-- | Index a rank-@n@ array with a full @n@-tuple of 'TIx', yielding an element.
(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
-- Defined in prefix form so it cannot be mistaken for a bang pattern.
(!) arr ix = NEIdx arr ix
infixl 9 !

-- | The shape of a rank-@n@ array, as an @n@-tuple of 'TIx'.
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape arr = NEShape arr

-- | Apply a primitive operation to its (possibly tupled) argument expression.
oper :: SOp a t -> NExpr env a -> NExpr env t
oper op arg = NEOp op arg

-- | An error expression carrying the given message. Its result type is
-- supplied through 'KnownTy'.
error_ :: KnownTy t => String -> NExpr env t
error_ = NEError knownTy

-- | Equality test on two numeric scalars of the same type.
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .== y = oper (OEq knownScalTy) $ pair x y
infix 4 .==

-- | Strict less-than on two numeric scalars of the same type.
(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .< y = oper (OLt knownScalTy) $ pair x y
infix 4 .<

-- | Strict greater-than, expressed as '.<' with the operands swapped.
(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .> y = y .< x
infix 4 .>

-- | Less-than-or-equal on two numeric scalars of the same type.
(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .<= y = oper (OLe knownScalTy) $ pair x y
infix 4 .<=

-- | Greater-than-or-equal, expressed as '.<=' with the operands swapped.
(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
x .>= y = y .<= x
infix 4 .>=

-- | Boolean negation.
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ b = oper ONot b

-- | Conditional expression. The first alternative is the True case; the
-- second is the False case.
--
-- The condition is converted by 'OIf' into a 'TEither'-typed value and
-- scrutinised with 'case_'. Each branch binds the payload to the throwaway
-- label @#_@ and weakens its expression with @NEDrop SZ@, so the branch does
-- not see that binding.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ cond whenTrue whenFalse =
  case_ (oper OIf cond)
        (#_ :-> NEDrop SZ whenTrue)
        (#_ :-> NEDrop SZ whenFalse)