{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid (unMonoid, zero, plus) where

import AST
import CHAD.Types
import Data


-- | Remove the monoid operations ('EZero', 'EPlus', 'EOneHot') from an
-- expression by expanding them into ordinary AST nodes using 'zero', 'plus'
-- and 'onehot'.
--
-- Every other node is rebuilt with the same constructor and payload. Its
-- annotation is replaced by 'ext', and all subexpressions are processed
-- recursively.
unMonoid :: Ex env t -> Ex env t
-- The nodes actually being eliminated.
unMonoid (EZero _ ty) = zero ty
unMonoid (EPlus _ ty x y) = plus ty (unMonoid x) (unMonoid y)
unMonoid (EOneHot _ ty prj ix val) = onehot ty prj (unMonoid ix) (unMonoid val)
-- Structural recursion over everything else.
unMonoid (EVar _ ty i) = EVar ext ty i
unMonoid (ELet _ rhs body) = ELet ext (unMonoid rhs) (unMonoid body)
unMonoid (EPair _ x y) = EPair ext (unMonoid x) (unMonoid y)
unMonoid (EFst _ x) = EFst ext (unMonoid x)
unMonoid (ESnd _ x) = ESnd ext (unMonoid x)
unMonoid (ENil _) = ENil ext
unMonoid (EInl _ ty x) = EInl ext ty (unMonoid x)
unMonoid (EInr _ ty x) = EInr ext ty (unMonoid x)
unMonoid (ECase _ scrut l r) = ECase ext (unMonoid scrut) (unMonoid l) (unMonoid r)
unMonoid (ENothing _ ty) = ENothing ext ty
unMonoid (EJust _ x) = EJust ext (unMonoid x)
unMonoid (EMaybe _ x y scrut) = EMaybe ext (unMonoid x) (unMonoid y) (unMonoid scrut)
unMonoid (EConstArr _ n ty v) = EConstArr ext n ty v
unMonoid (EBuild _ n sh body) = EBuild ext n (unMonoid sh) (unMonoid body)
unMonoid (EFold1Inner _ x y z) = EFold1Inner ext (unMonoid x) (unMonoid y) (unMonoid z)
unMonoid (ESum1Inner _ x) = ESum1Inner ext (unMonoid x)
unMonoid (EUnit _ x) = EUnit ext (unMonoid x)
unMonoid (EReplicate1Inner _ x y) = EReplicate1Inner ext (unMonoid x) (unMonoid y)
unMonoid (EMaximum1Inner _ x) = EMaximum1Inner ext (unMonoid x)
unMonoid (EMinimum1Inner _ x) = EMinimum1Inner ext (unMonoid x)
unMonoid (EConst _ ty v) = EConst ext ty v
unMonoid (EIdx0 _ x) = EIdx0 ext (unMonoid x)
unMonoid (EIdx1 _ x y) = EIdx1 ext (unMonoid x) (unMonoid y)
unMonoid (EIdx _ x y) = EIdx ext (unMonoid x) (unMonoid y)
unMonoid (EShape _ x) = EShape ext (unMonoid x)
unMonoid (EOp _ op x) = EOp ext op (unMonoid x)
unMonoid (ECustom _ t1 t2 t3 x y z e1 e2) =
  ECustom ext t1 t2 t3
    (unMonoid x) (unMonoid y) (unMonoid z)
    (unMonoid e1) (unMonoid e2)
unMonoid (EWith _ ty x y) = EWith ext ty (unMonoid x) (unMonoid y)
unMonoid (EAccum _ ty prj x y z) = EAccum ext ty prj (unMonoid x) (unMonoid y) (unMonoid z)
unMonoid (EError _ ty msg) = EError ext ty msg

-- | The zero cotangent of the given type: the value that 'plus' treats as
-- its identity.
--
-- * Pairs, sums and 'Maybe' have sparse cotangents, so their zero is
--   'ENothing'.
-- * Arrays of rank >= 1 become an empty array: every dimension of the shape
--   is 0, so the 'EError' element expression should never be evaluated.
-- * Rank-0 arrays are special-cased with 'EUnit' around the element's zero.
--   Presumably this is because a rank-0 array always holds one element and
--   cannot be empty.
-- * Integer and boolean scalars have unit cotangents ('ENil'). Floating-point
--   scalars get the literal @0.0@.
-- * Accumulator types are rejected with a runtime 'error'.
zero :: STy t -> Ex env (D2 t)
zero = \case
  STNil -> ENil ext
  STPair ta tb -> ENothing ext (STPair (d2 ta) (d2 tb))
  STEither ta tb -> ENothing ext (STEither (d2 ta) (d2 tb))
  STMaybe ta -> ENothing ext (d2 ta)
  STArr SZ ta -> EUnit ext (zero ta)
  STArr rank ta ->
    let emptyShape = eTup (sreplicate rank (EConst ext STI64 0))
    in EBuild ext rank emptyShape (EError ext (d2 ta) "empty")
  STScal STI32 -> ENil ext
  STScal STI64 -> ENil ext
  STScal STBool -> ENil ext
  STScal STF32 -> EConst ext STF32 0.0
  STScal STF64 -> EConst ext STF64 0.0
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Add two cotangents of the given type. 'zero' is the identity of this
-- operation.
--
-- * Pairs, sums and 'Maybe' use 'plusSparse'. If either operand is absent
--   the other is returned; otherwise an adder combines the two payloads.
--   Inside that adder, @IS IZ@ is the payload of the first operand @a@ and
--   @IZ@ is the payload of the second operand @b@.
-- * Arrays: an operand with an empty shape is treated as zero and the other
--   operand is returned as-is. Otherwise the arrays are combined elementwise
--   with 'ezipWith'.
-- * Integer and boolean scalars have unit cotangents. Floating-point scalars
--   are added with 'OAdd'.
-- * Accumulator types are rejected with a runtime 'error'.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
plus (STPair t1 t2) a b =
  let t = STPair (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Componentwise: (fst a + fst b, snd a + snd b).
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
                          (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ)))
                          (ESnd ext (EVar ext t IZ)))
plus (STEither t1 t2) a b =
  let t = STEither (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Both operands must be on the same side. Mixing Inl and Inr
       -- produces an 'EError'.
       ECase ext (EVar ext t (IS IZ))
         -- a = Inl x. The env gained x, so IS IZ now refers to b.
         -- In the nested both-Inl branch, IS IZ is a's payload and IZ is
         -- b's payload.
         (ECase ext (EVar ext t (IS IZ))
           (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
           (EError ext t "plus l+r"))
         -- a = Inr x: mirror image of the Inl case.
         (ECase ext (EVar ext t (IS IZ))
           (EError ext t "plus r+l")
           (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
plus (STMaybe t) a b =
  plusSparse (d2 t) a b $
    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
plus (STArr n t) a b =
  -- Bind a, then b. In the body below, IS IZ is a and IZ is b.
  ELet ext a $
  ELet ext (weakenExpr WSink b) $
    -- a is empty: the result is b.
    eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
        (EVar ext (STArr n (d2 t)) IZ)
        -- b is empty: the result is a.
        (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
             (EVar ext (STArr n (d2 t)) (IS IZ))
             -- Both non-empty: add elementwise. NOTE(review): this assumes
             -- ezipWith binds the element of its first array as IS IZ and
             -- of its second as IZ. Confirm against ezipWith's definition.
             (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
               (EVar ext (STArr n (d2 t)) (IS IZ))
               (EVar ext (STArr n (d2 t)) IZ)))
plus (STScal t) a b = case t of
  -- Integer and boolean scalars have unit cotangents: nothing to add.
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
  STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
  STBool -> ENil ext
plus STAccum{} _ _ = error "Accumulators not allowed in input program"

-- | Add two sparse ('TMaybe'-wrapped) cotangents.
--
-- If @a@ is absent, the result is @b@. If @b@ is absent, the result is @a@.
-- If both are present, @adder@ combines their payloads. @adder@ runs in an
-- environment extended with two variables: @IS IZ@ is the payload of @a@
-- and @IZ@ is the payload of @b@.
--
-- @b@ is let-bound before @a@ is inspected.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse t a b adder =
  ELet ext b $ EMaybe ext whenANothing whenAJust (weakenExpr WSink a)
  where
    -- Env: (b : env). With no payload in a, the answer is b itself.
    whenANothing = EVar ext (STMaybe t) IZ
    -- Env: (x : b : env), where x is a's payload. Inspect b next.
    whenAJust =
      EJust ext (EMaybe ext whenBNothing whenBothJust (EVar ext (STMaybe t) (IS IZ)))
    -- Env: (x : b : env). With no payload in b, the answer is x.
    whenBNothing = EVar ext t IZ
    -- Env: (y : x : b : env). Skip the b binding so that adder sees
    -- (y : x : env).
    whenBothJust = weakenExpr (WCopy (WCopy WSink)) adder

-- | Build a cotangent of type @t@ that is 'zero' everywhere except at the
-- position selected by the projection and the index expression, where it
-- holds @arg@.
--
-- * Pair, sum and 'Maybe' levels on the path are wrapped in 'EJust'. For a
--   pair, the component off the path is 'zero'.
-- * An array level reads its index value as ((position, shape), rest).
--   * It builds an array of that shape.
--   * The element at that position is the one-hot of @rest@.
--   * Every other element is 'zero'.
onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
onehot _ SAPHere _ arg = arg
onehot (STPair t1 t2) (SAPFst prj) idx arg =
  EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
onehot (STPair t1 t2) (SAPSnd prj) idx arg =
  EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
onehot (STEither t1 t2) (SAPLeft prj) idx arg =
  EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
onehot (STEither t1 t2) (SAPRight prj) idx arg =
  EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
onehot (STMaybe t1) (SAPJust prj) idx arg =
  EJust ext (onehot t1 prj idx arg)
onehot (STArr n t1) (SAPArrIdx prj _) idx arg =
  let tIdxFull = typeOf idx                -- type of the whole index value
      tPos = tTup (sreplicate n tIx)       -- type of one array position
  in ELet ext idx $
       -- Env: (idx : env). The shape is snd (fst idx).
       EBuild ext n (ESnd ext (EFst ext (EVar ext tIdxFull IZ))) $
         -- Env: (pos : idx : env). Is pos the target position fst (fst idx)?
         eif (eidxEq n (EVar ext tPos IZ) (EFst ext (EFst ext (EVar ext tIdxFull (IS IZ)))))
           (onehot t1 prj (ESnd ext (EVar ext tIdxFull (IS IZ))) (weakenExpr (WSink .> WSink) arg))
           (zero t1)