From 8f7445780664d2739c282bd3a83c12caebd9f461 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 24 Apr 2025 21:19:01 +0200 Subject: test: Fix template constraint system --- test/Main.hs | 55 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 15 deletions(-) (limited to 'test') diff --git a/test/Main.hs b/test/Main.hs index 5fa1d46..afbd79b 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} @@ -117,18 +118,22 @@ closeIshT = closeIshT' 1e-5 data a :$ b = a :$ b deriving (Show) ; infixl :$ --- An empty name means "no restrictions". -data TplConstr = C String -- ^ name; @""@ means anonymous - Int -- ^ minimum value to generate +-- | The type index is just a marker that helps typed holes show what (type of) +-- argument this template constraint belongs to. +data TplConstr a = C String -- ^ name; @""@ means anonymous + Int -- ^ minimum value to generate + | NC -- ^ no constraints type family DimNames n where DimNames Z = () - DimNames (S Z) = TplConstr - DimNames (S n) = DimNames n :$ TplConstr + DimNames (S Z) = TplConstr (S Z) + DimNames (S n) = DimNames n :$ TplConstr (S n) type family Tpl t where Tpl (TArr n t) = DimNames n Tpl (TPair a b) = (Tpl a, Tpl b) + Tpl (TScal TI32) = TplConstr TI32 + Tpl (TScal TI64) = TplConstr TI64 -- If you add equations here, don't forget to update genValue! It currently -- just emptyTpl's things out. Tpl _ = () @@ -142,13 +147,17 @@ type family TemplateE env where emptyDimNames :: SNat n -> DimNames n emptyDimNames SZ = () -emptyDimNames (SS SZ) = C "" 0 -emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0 +emptyDimNames (SS SZ) = NC +emptyDimNames (SS n@SS{}) = emptyDimNames n :$ NC emptyTpl :: STy t -> Tpl t emptyTpl (STArr n _) = emptyDimNames n emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b) -emptyTpl (STScal _) = () +emptyTpl (STScal STI32) = NC +emptyTpl (STScal STI64) = NC +emptyTpl (STScal STF32) = () +emptyTpl (STScal STF64) = () +emptyTpl (STScal STBool) = () emptyTpl _ = error "too lazy" emptyTemplateE :: SList STy env -> TemplateE env @@ -168,7 +177,8 @@ genShape = \n tpl -> do genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name - genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int + genNamedDim :: TplConstr a -> StateT (Map String Int) Gen Int + genNamedDim NC = genDim 0 genNamedDim (C "" lo) = genDim lo genNamedDim (C name lo) = gets (Map.lookup name) >>= \case Nothing -> do @@ -182,15 +192,17 @@ genShape = \n tpl -> do shapeDiv :: Shape n -> DimNames n -> Int -> Shape n shapeDiv ShNil _ _ = ShNil - shapeDiv (ShNil `ShCons` n) (C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) + shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f)) + shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f) genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> evalStateT (genValue t (emptyTpl t)) mempty) -genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) +genValue :: forall t. STy t -> Tpl t -> StateT (Map String Int) Gen (Value t) genValue topty tpl = case topty of STNil -> return (Value ()) STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl) @@ -202,10 +214,23 @@ genValue topty tpl = case topty of STScal sty -> case sty of STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) - STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) + STI32 -> genInt + STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" + where + genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t) + genInt = do + let gen lo = Gen.integral (Range.linearFrom 0 lo (max 10 (lo + 10))) + val <- case tpl of + NC -> gen (-10) + C name lo -> gets (Map.lookup name) >>= \case + Nothing -> do + val <- fromIntegral @Int @(Rep t) <$> gen lo + modify (Map.insert name (fromIntegral @(Rep t) @Int val)) + return val + Just val -> return (fromIntegral @Int @(Rep t) val) + return (Value val) genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env) genEnv SNil () = return SNil @@ -538,7 +563,7 @@ tests_AD = testGroup "AD" 42 ,adTest "arr-rebind-simple" term_arr_rebind_simple - ,adTest "arr-rebind" term_arr_rebind + ,adTestTp "arr-rebind" (NC :& C "" 0) term_arr_rebind ,adTestGen "neural" Example.neural gen_neural @@ -549,7 +574,7 @@ tests_AD = testGroup "AD" let_ #m (maximum1i #vec) $ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m - ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec + ,adTestTp "mulmatvec" ((NC :$ C "n" 0) :& C "n" 0) term_mulmatvec ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm -- cgit v1.2.3-70-g09d2