X-Git-Url: http://git.megacz.com/?p=coq-hetmet.git;a=blobdiff_plain;f=examples%2FGArrowPortShape.hs;h=45cbdb248e8ee419d49a3354cad1b049ad045ce6;hp=8e5078bd5e1b1a5c8fecab9e0a9bf6ea8b1625c5;hb=ec996e8cb550676d89d187061db7d018af9ec88d;hpb=4f4443fab27bf4724b7a1b5a92a48d5a58c440d7 diff --git a/examples/GArrowPortShape.hs b/examples/GArrowPortShape.hs index 8e5078b..45cbdb2 100644 --- a/examples/GArrowPortShape.hs +++ b/examples/GArrowPortShape.hs @@ -19,11 +19,11 @@ -- information for certain nodes (the inference mechanism below adds -- it on every node). -- -module GArrowPortShape (GArrowPortShape(..), PortShape(..)) +module GArrowPortShape (GArrowPortShape(..), PortShape(..), detectShape, Detect(..), DetectM, freshM) where import Prelude hiding ( id, (.), lookup ) import Control.Category -import GHC.HetMet.GArrow +import Control.GArrow import Unify import GArrowSkeleton import Control.Monad.State @@ -40,6 +40,11 @@ data PortShape a = PortUnit | PortTensor (PortShape a) (PortShape a) | PortFree a +instance Show a => Show (PortShape a) where + show PortUnit = "U" + show (PortTensor p1 p2) = "("++show p1++"*"++show p2++")" + show (PortFree x) = show x + data GArrowPortShape m s a b = GASPortPassthrough (PortShape s) @@ -58,7 +63,13 @@ type UPort = PortShape UVar instance Unifiable UPort where unify' (PortTensor x1 y1) (PortTensor x2 y2) = mergeU (unify x1 x2) (unify y1 y2) - unify' _ _ = error "impossible" + unify' PortUnit PortUnit = emptyUnifier + unify' s1 s2 = error $ "Unifiable UPort got impossible unification case: " + + replace uv prep PortUnit = PortUnit + replace uv prep (PortTensor p1 p2) = PortTensor (replace uv prep p1) (replace uv prep p2) + replace uv prep (PortFree x) = if x==uv then prep else PortFree x + inject = PortFree project (PortFree v) = Just v project _ = Nothing @@ -117,18 +128,24 @@ resolveG u (GASPortShapeWrapper x y g) = GASPortShapeWrapper (getU' u x) (getU' resolveG' (GAS_loopr f) = GAS_loopr (resolveG' f) resolveG' (GAS_misc g ) = GAS_misc $ resolveG u g -runM :: DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b +detectShape :: Detect m => GArrowSkeleton m a b -> GArrowPortShape m () a b +detectShape g = runM (detect g) + +runM :: Detect m => DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b runM f = let s = (emptyUnifier,uvarSupply) g = evalState f s (u,_) = execState f s in resolveG u g -detect :: GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b) +class Detect m where + detect' :: m x y -> DetectM (GArrowPortShape m UVar x y) + +detect :: Detect m => GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b) detect (GAS_id ) = do { x <- freshM ; return $ GASPortShapeWrapper (PortFree x) (PortFree x) GAS_id } -detect (GAS_comp g f) = do { f' <- detect f +detect (GAS_comp f g) = do { f' <- detect f ; g' <- detect g ; unifyM (snd $ shapes f') (fst $ shapes g') - ; return $ GASPortShapeWrapper (fst $ shapes f') (snd $ shapes g') (GAS_comp (GAS_misc g') (GAS_misc f')) + ; return $ GASPortShapeWrapper (fst $ shapes f') (snd $ shapes g') (GAS_comp (GAS_misc f') (GAS_misc g')) } detect (GAS_first f) = do { x <- freshM ; f' <- detect f @@ -176,10 +193,26 @@ detect GAS_unassoc = do { x <- freshM; y <- freshM; z <- freshM GAS_unassoc } detect (GAS_const i) = do { x <- freshM; return $ GASPortShapeWrapper PortUnit (PortFree x) (GAS_const i) } + detect GAS_merge = do { x <- freshM ; return $ GASPortShapeWrapper (PortTensor (PortFree x) (PortFree x)) (PortFree x) GAS_merge } -detect (GAS_loopl f) = error "not implemented" -detect (GAS_loopr f) = error "not implemented" -detect (GAS_misc f) = error "not implemented" +detect (GAS_loopl f) = do { x <- freshM + ; y <- freshM + ; z <- freshM + ; f' <- detect f + ; unifyM (fst $ shapes f') (PortTensor (PortFree z) (PortFree x)) + ; unifyM (snd $ shapes f') (PortTensor (PortFree z) (PortFree y)) + ; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopl (GAS_misc f')) + } +detect (GAS_loopr f) = do { x <- freshM + ; y <- freshM + ; z <- freshM + ; f' <- detect f + ; unifyM (fst $ shapes f') (PortTensor (PortFree x) (PortFree z)) + ; unifyM (snd $ shapes f') (PortTensor (PortFree y) (PortFree z)) + ; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopr (GAS_misc f')) + } + +detect (GAS_misc f) = detect' f