1 {-# LANGUAGE MultiParamTypeClasses, GADTs, FlexibleContexts, FlexibleInstances, TypeFamilies #-}
2 -----------------------------------------------------------------------------
4 -- Module : GArrowPortShape
6 -- License : public domain
8 -- Maintainer : Adam Megacz <megacz@acm.org>
9 -- Stability : experimental
11 -- | We cannot, at run time, query to find out the input and output
12 -- port types of a GArrowSkeleton since Haskell erases types during
13 -- compilation. Using Data.Typeable is problematic here because
14 -- GAS_comp and GAS_loop{l,r} have an existential type.
16 -- In spite of this, we can determine the "shape" of the ports --
17 -- which ports are of unit type, and which ports must be tensors. A
18 -- GArrowPortShape is a GArrowSkeleton along with this
-- information for certain nodes (the inference mechanism below adds it to every node it processes).
22 module GArrowPortShape (GArrowPortShape(..), PortShape(..), detectShape)
24 import Prelude hiding ( id, (.), lookup )
25 import Control.Category
26 import GHC.HetMet.GArrow
29 import Control.Monad.State
-- | The shape of a port: which parts of its type are the unit type,
-- which are tensors (pairs), and which are not constrained any further.
--
-- Please keep in mind that the "shapes" computed below are simply the
-- least-complicated shapes that could possibly work.  Just because a
-- GArrowPortShape has an input port of shape (x,y) doesn't mean it
-- couldn't later be used in a context where its input port had shape
-- ((a,b),y)!  However, you can be assured that it won't be used in a
-- context where the input port has shape ().
data PortShape a
  = PortUnit
    -- ^ a port of unit type, @()@
  | PortTensor (PortShape a) (PortShape a)
    -- ^ a port whose type is a pair of the two given shapes
  | PortFree a
    -- ^ a port whose shape is not constrained any further; the payload
    -- names it (a unification variable during inference, @()@ once resolved)
-- | Compact rendering of a shape: @U@ for unit, @(p*q)@ for a tensor,
-- and the payload's own 'show' for a free leaf.
instance Show a => Show (PortShape a) where
  -- "U" rather than "()" so that unit cannot be confused with a resolved
  -- free leaf, whose payload () also shows as "()".
  show PortUnit           = "U"
  show (PortTensor p1 p2) = "(" ++ show p1 ++ "*" ++ show p2 ++ ")"
  show (PortFree x)       = show x
-- | A 'GArrowSkeleton' whose nodes are annotated with the shapes of their
-- input and output ports (first field = input, second = output).  The
-- parameter @s@ is the payload carried by free shape leaves.
data GArrowPortShape m s a b
  = GASPortPassthrough
      (PortShape s)   -- input port shape
      (PortShape s)   -- output port shape
      (m a b)         -- an opaque primitive of the underlying arrow m
  | GASPortShapeWrapper
      (PortShape s)   -- input port shape
      (PortShape s)   -- output port shape
      (GArrowSkeleton (GArrowPortShape m s) a b)
59 -- implementation below; none of this is exported
-- | Port shapes whose free leaves are unification variables; this is the
-- form used while shapes are being inferred.
type UPort
  = PortShape UVar
-- | Structural unification of port shapes: tensors unify componentwise
-- and unit unifies with unit.  Free leaves are exposed to the Unify
-- machinery through 'project' and 'occurrences'.
-- NOTE(review): binding of variables themselves is presumably handled by
-- the Unify module (not visible here) -- confirm.
instance Unifiable UPort where
  unify' (PortTensor x1 y1) (PortTensor x2 y2) = mergeU (unify x1 x2) (unify y1 y2)
  unify' PortUnit           PortUnit           = emptyUnifier
  -- e.g. unit against a tensor: the shapes cannot be made equal
  unify' _                  _                  = error $ "Unifiable UPort got impossible unification case: "

  -- substitute the shape prep for every occurrence of the variable uv
  replace _  _    PortUnit           = PortUnit
  replace uv prep (PortTensor p1 p2) = PortTensor (replace uv prep p1) (replace uv prep p2)
  replace uv prep (PortFree x)       = if x == uv then prep else PortFree x

  -- a shape is a bare variable exactly when it is a free leaf; every
  -- other shape must answer Nothing (this case was previously missing)
  project (PortFree v) = Just v
  project _            = Nothing

  -- every variable mentioned anywhere in the shape, left to right
  occurrences (PortFree v)     = [v]
  occurrences (PortTensor x y) = occurrences x ++ occurrences y
  occurrences PortUnit         = []
-- | The inference monad: state is the unifier accumulated so far paired
-- with the supply of not-yet-used unification variables.
type DetectM a = State (Unifier UPort, [UVar]) a
-- | The (input, output) port shapes recorded on an annotated node,
-- whichever constructor it uses.
shapes :: GArrowPortShape m UVar a b -> (UPort,UPort)
shapes node = case node of
  GASPortPassthrough  inShape outShape _ -> (inShape, outShape)
  GASPortShapeWrapper inShape outShape _ -> (inShape, outShape)
-- | Record the constraint that two port shapes must be equal: the unifier
-- for the pair is merged into the accumulated one.  The variable supply is
-- left untouched.  (The previous get/put block was never closed.)
unifyM :: UPort -> UPort -> DetectM ()
unifyM p1 p2 = modify (\(u, vars) -> (mergeU u (unify p1 p2), vars))
-- | Take the next unused unification variable from the supply and remove
-- it from the state.
--
-- The supply is inspected with an explicit case rather than a failable
-- @(u, v:vars) <- get@ bind: that bind needs a 'MonadFail' instance, which
-- 'State' does not have on GHC >= 8.8.
freshM :: DetectM UVar
freshM = do
  (u, supply) <- get
  case supply of
    v : rest -> do
      put (u, rest)
      return v
    -- NOTE(review): unreachable if uvarSupply is infinite -- confirm
    [] -> error "GArrowPortShape.freshM: supply of unification variables exhausted"
-- | Recursive version of getU: fully resolve a port shape against a
-- unifier, substituting bound variables repeatedly until none remain.
-- Variables the unifier leaves unbound become @PortFree ()@.
getU' :: Unifier UPort -> UPort -> PortShape ()
getU' u = go
  where
    go PortUnit         = PortUnit
    go (PortTensor l r) = PortTensor (go l) (go r)
    go (PortFree v)     = maybe (PortFree ()) go (Unify.getU u v)
-- | Resolve every port-shape annotation of an inferred skeleton against
-- the final unifier.  This turns a tree annotated with unification
-- variables into one whose free leaves carry @()@, recursing through the
-- wrapped skeleton.
resolveG :: Unifier UPort -> (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
resolveG u (GASPortPassthrough  x y m) = GASPortPassthrough  (getU' u x) (getU' u y) m
resolveG u (GASPortShapeWrapper x y g) = GASPortShapeWrapper (getU' u x) (getU' u y) (resolveG' g)
  where
    -- Local so that the GAS_misc case can reach the unifier u; as a
    -- top-level function, u was unbound there.  The explicit signature is
    -- needed for the polymorphic recursion through GAS_comp (whose middle
    -- type is existential) and through nested annotated nodes.
    resolveG' :: GArrowSkeleton (GArrowPortShape m' UVar) a' b'
              -> GArrowSkeleton (GArrowPortShape m' ()) a' b'
    resolveG' GAS_id         = GAS_id
    resolveG' (GAS_comp f h) = GAS_comp (resolveG' f) (resolveG' h)
    resolveG' (GAS_first f)  = GAS_first (resolveG' f)
    resolveG' (GAS_second f) = GAS_second (resolveG' f)
    resolveG' GAS_cancell    = GAS_cancell
    resolveG' GAS_cancelr    = GAS_cancelr
    resolveG' GAS_uncancell  = GAS_uncancell
    resolveG' GAS_uncancelr  = GAS_uncancelr
    resolveG' GAS_drop       = GAS_drop
    resolveG' (GAS_const i)  = GAS_const i
    resolveG' GAS_copy       = GAS_copy
    resolveG' GAS_merge      = GAS_merge
    resolveG' GAS_swap       = GAS_swap
    resolveG' GAS_assoc      = GAS_assoc
    resolveG' GAS_unassoc    = GAS_unassoc
    resolveG' (GAS_loopl f)  = GAS_loopl (resolveG' f)
    resolveG' (GAS_loopr f)  = GAS_loopr (resolveG' f)
    resolveG' (GAS_misc n)   = GAS_misc (resolveG u n)
-- | Infer the least-complicated port shapes for a skeleton (see 'PortShape')
-- and return the annotated, fully resolved skeleton.
detectShape :: GArrowSkeleton m a b -> GArrowPortShape m () a b
detectShape = runM . detect
-- | Run an inference action starting from an empty unifier and the full
-- variable supply.  Every shape in the resulting skeleton is then resolved
-- against the final unifier.  (The previous @let@ had no @in@ body.)
runM :: DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
runM f = resolveG u annotated
  where
    -- a single run yields both the annotated skeleton and the final state
    (annotated, (u, _)) = runState f (emptyUnifier, uvarSupply)
-- | Infer port shapes for a skeleton.  Each structural node is wrapped in
-- a 'GASPortShapeWrapper' whose shapes mention fresh unification
-- variables, and the constraints between neighbouring nodes are recorded
-- in the state via 'unifyM'.
detect :: GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b)
-- identity: input and output share a single unconstrained shape
detect GAS_id = do
  v <- freshM
  let p = PortFree v
  return (GASPortShapeWrapper p p GAS_id)
-- composition: the output shape of f must agree with the input shape of g
-- (g' was previously unbound and the block unterminated)
detect (GAS_comp f g) = do
  f' <- detect f
  g' <- detect g
  unifyM (snd $ shapes f') (fst $ shapes g')
  return $ GASPortShapeWrapper (fst $ shapes f') (snd $ shapes g') (GAS_comp (GAS_misc f') (GAS_misc g'))
-- first f: f acts on the left component; the right one passes unchanged
-- (f' was previously unbound and the block unterminated)
detect (GAS_first f) = do
  x <- freshM
  f' <- detect f
  return $ GASPortShapeWrapper
             (PortTensor (fst $ shapes f') (PortFree x))
             (PortTensor (snd $ shapes f') (PortFree x))
             (GAS_first (GAS_misc f'))
-- second f: f acts on the right component; the left one passes unchanged
-- (f' was previously unbound and the block unterminated)
detect (GAS_second f) = do
  x <- freshM
  f' <- detect f
  return $ GASPortShapeWrapper
             (PortTensor (PortFree x) (fst $ shapes f'))
             (PortTensor (PortFree x) (snd $ shapes f'))
             (GAS_second (GAS_misc f'))
-- structural maps that remove or introduce a unit component
detect GAS_cancell = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper (PortTensor PortUnit p) p GAS_cancell
detect GAS_cancelr = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper (PortTensor p PortUnit) p GAS_cancelr
detect GAS_uncancell = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper p (PortTensor PortUnit p) GAS_uncancell
detect GAS_uncancelr = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper p (PortTensor p PortUnit) GAS_uncancelr
-- drop: any input, unit output
detect GAS_drop = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper p PortUnit GAS_drop
-- copy: the output is a tensor of two copies of the input shape
detect GAS_copy = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper p (PortTensor p p) GAS_copy
-- swap: exchange the two halves of a tensor
-- (y was previously unbound and the block unterminated)
detect GAS_swap = do
  x <- freshM
  y <- freshM
  let x' = PortFree x
      y' = PortFree y
  return $ GASPortShapeWrapper (PortTensor x' y') (PortTensor y' x') GAS_swap
-- assoc: ((x,y),z) to (x,(y,z))
-- (the wrapped node argument and the closing brace were missing)
detect GAS_assoc = do
  x <- freshM
  y <- freshM
  z <- freshM
  let x' = PortFree x
      y' = PortFree y
      z' = PortFree z
  return $ GASPortShapeWrapper
             (PortTensor (PortTensor x' y') z')
             (PortTensor x' (PortTensor y' z'))
             GAS_assoc
-- unassoc: (x,(y,z)) to ((x,y),z)
-- (the wrapped node argument and the closing brace were missing)
detect GAS_unassoc = do
  x <- freshM
  y <- freshM
  z <- freshM
  let x' = PortFree x
      y' = PortFree y
      z' = PortFree z
  return $ GASPortShapeWrapper
             (PortTensor x' (PortTensor y' z'))
             (PortTensor (PortTensor x' y') z')
             GAS_unassoc
-- const: unit input, unconstrained output
detect (GAS_const i) = do
  v <- freshM
  return $ GASPortShapeWrapper PortUnit (PortFree v) (GAS_const i)
-- merge: a tensor of two like-shaped halves in, one such shape out
detect GAS_merge = do
  v <- freshM
  let p = PortFree v
  return $ GASPortShapeWrapper (PortTensor p p) p GAS_merge
-- loopl f: f maps (z,x) to (z,y); the fed-back left component z is hidden,
-- so the loop maps x to y
-- (y, z and f' were previously unbound and the block unterminated)
detect (GAS_loopl f) = do
  x <- freshM
  y <- freshM
  z <- freshM
  f' <- detect f
  unifyM (fst $ shapes f') (PortTensor (PortFree z) (PortFree x))
  unifyM (snd $ shapes f') (PortTensor (PortFree z) (PortFree y))
  return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopl (GAS_misc f'))
-- loopr f: f maps (x,z) to (y,z); the fed-back right component z is hidden,
-- so the loop maps x to y
-- (y, z and f' were previously unbound and the block unterminated)
detect (GAS_loopr f) = do
  x <- freshM
  y <- freshM
  z <- freshM
  f' <- detect f
  unifyM (fst $ shapes f') (PortTensor (PortFree x) (PortFree z))
  unifyM (snd $ shapes f') (PortTensor (PortFree y) (PortFree z))
  return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopr (GAS_misc f'))
-- an opaque primitive of the underlying arrow: nothing is known about its
-- ports, so each gets a fresh, unconstrained shape, and the primitive is
-- kept as a passthrough node (which resolveG already handles) instead of
-- failing with "not implemented".
-- NOTE(review): assumes GASPortPassthrough's last field has type (m a b).
detect (GAS_misc f) = do
  x <- freshM
  y <- freshM
  return $ GASPortPassthrough (PortFree x) (PortFree y) f