newDictFromOld, newDicts, newDictsAtLoc,
newMethod, newMethodWithGivenTy, newOverloadedLit, instOverloadedFun,
- tyVarsOfInst, instLoc, getDictClassTys,
+ tyVarsOfInst, instLoc, getDictClassTys, getFunDeps,
lookupInst, lookupSimpleInst, LookupInstResult(..),
- isDict, isTyVarDict, isStdClassTyVarDict, isMethodFor,
+ isDict, isTyVarDict, isStdClassTyVarDict, isMethodFor, notFunDep,
instBindingRequired, instCanBeGeneralised,
- zonkInst, instToId, instToIdBndr,
+ zonkInst, zonkFunDeps, instToId, instToIdBndr,
InstOrigin(..), InstLoc, pprInstLoc
) where
)
import Bag
import Class ( classInstEnv, Class )
+import FunDeps ( instantiateFdClassTys )
import Id ( Id, idFreeTyVars, idType, mkUserLocal, mkSysLocal )
-import VarSet ( elemVarSet )
import PrelInfo ( isStandardClass, isCcallishClass, isNoDictClass )
import Name ( OccName, Name, mkDictOcc, mkMethodOcc, getOccName )
import PprType ( pprConstraint )
substTy, substTheta, mkTyVarSubst, mkTopTyVarSubst
)
import TyCon ( TyCon )
+import Var ( TyVar )
import VarEnv ( lookupVarEnv, TidyEnv,
lookupSubstEnv, SubstResult(..)
)
-import VarSet ( unionVarSet )
+import VarSet ( elemVarSet, emptyVarSet, unionVarSet )
import TysPrim ( intPrimTy, floatPrimTy, doublePrimTy )
import TysWiredIn ( intDataCon, isIntTy, inIntRange,
floatDataCon, isFloatTy,
TcType -- The type at which the literal is used
InstLoc
+ | FunDep
+ Class -- the class from which this arises
+ [([TcType], [TcType])]
+ InstLoc
+
data OverloadedLit
= OverloadedIntegral Integer -- The number
| OverloadedFractional Rational -- The number
cmpInst (LitInst _ lit1 ty1 _) (LitInst _ lit2 ty2 _)
= (lit1 `cmpOverLit` lit2) `thenCmp` (ty1 `compare` ty2)
+cmpInst (LitInst _ _ _ _) (FunDep _ _ _)
+ = LT
cmpInst (LitInst _ _ _ _) other
= GT
+cmpInst (FunDep clas1 fds1 _) (FunDep clas2 fds2 _)
+ = (clas1 `compare` clas2) `thenCmp` (fds1 `compare` fds2)
+cmpInst (FunDep _ _ _) other
+ = GT
+
cmpOverLit (OverloadedIntegral i1) (OverloadedIntegral i2) = i1 `compare` i2
cmpOverLit (OverloadedFractional f1) (OverloadedFractional f2) = f1 `compare` f2
cmpOverLit (OverloadedIntegral _) (OverloadedFractional _) = LT
instLoc (Dict u clas tys loc) = loc
instLoc (Method u _ _ _ _ loc) = loc
instLoc (LitInst u lit ty loc) = loc
+instLoc (FunDep _ _ loc) = loc
getDictClassTys (Dict u clas tys _) = (clas, tys)
+getFunDeps (FunDep clas fds _) = Just (clas, fds)
+getFunDeps _ = Nothing
+
tyVarsOfInst :: Inst -> TcTyVarSet
tyVarsOfInst (Dict _ _ tys _) = tyVarsOfTypes tys
tyVarsOfInst (Method _ id tys _ _ _) = tyVarsOfTypes tys `unionVarSet` idFreeTyVars id
-- The id might have free type variables; in the case of
-- locally-overloaded class methods, for example
tyVarsOfInst (LitInst _ _ ty _) = tyVarsOfType ty
+tyVarsOfInst (FunDep _ fds _)
+ = foldr unionVarSet emptyVarSet (map tyVarsOfFd fds)
+ where tyVarsOfFd (ts1, ts2) =
+ tyVarsOfTypes ts1 `unionVarSet` tyVarsOfTypes ts1
\end{code}
Predicates
isStdClassTyVarDict (Dict _ clas [ty] _) = isStandardClass clas && isTyVarTy ty
isStdClassTyVarDict other = False
+
+notFunDep :: Inst -> Bool
+notFunDep (FunDep _ _ _) = False
+notFunDep other = True
\end{code}
Two predicates which deal with the case where class constraints don't
instOverloadedFun orig (HsVar v) arg_tys theta tau
= newMethodWithGivenTy orig v arg_tys theta tau `thenNF_Tc` \ inst ->
- returnNF_Tc (HsVar (instToId inst), unitLIE inst)
+ instFunDeps orig theta `thenNF_Tc` \ fds ->
+ returnNF_Tc (HsVar (instToId inst), mkLIE (inst : fds))
+ --returnNF_Tc (HsVar (instToId inst), unitLIE inst)
+
+instFunDeps orig theta
+ = tcGetInstLoc orig `thenNF_Tc` \ loc ->
+ let ifd (clas, tys) = FunDep clas (instantiateFdClassTys clas tys) loc in
+ returnNF_Tc (map ifd theta)
newMethodWithGivenTy orig id tys theta tau
= tcGetInstLoc orig `thenNF_Tc` \ loc ->
instToIdBndr (LitInst u list ty loc)
= mkSysLocal SLIT("lit") u ty
+
+instToIdBndr (FunDep clas fds _)
+ = panic "FunDep escaped!!!"
\end{code}
zonkInst (LitInst u lit ty loc)
= zonkTcType ty `thenNF_Tc` \ new_ty ->
returnNF_Tc (LitInst u lit new_ty loc)
+
+zonkInst (FunDep clas fds loc)
+ = zonkFunDeps fds `thenNF_Tc` \ fds' ->
+ returnNF_Tc (FunDep clas fds' loc)
+
+zonkFunDeps fds = mapNF_Tc zonkFd fds
+ where
+ zonkFd (ts1, ts2)
+ = zonkTcTypes ts1 `thenNF_Tc` \ ts1' ->
+ zonkTcTypes ts2 `thenNF_Tc` \ ts2' ->
+ returnNF_Tc (ts1', ts2')
\end{code}
brackets (interppSP tys),
show_uniq u]
+pprInst (FunDep clas fds loc)
+ = ptext SLIT("fundep!")
+
tidyInst :: TidyEnv -> Inst -> (TidyEnv, Inst)
tidyInst env (LitInst u lit ty loc)
= (env', LitInst u lit ty' loc)
-- Leave theta, tau alone cos we don't print them
where
(env', tys') = tidyOpenTypes env tys
-
+
+-- this case shouldn't arise... (we never print fundeps)
+tidyInst env fd@(FunDep clas fds loc)
+ = (env, fd)
+
tidyInsts env insts = mapAccumL tidyInst env insts
show_uniq u = ifPprDebug (text "{-" <> ppr u <> text "-}")
doubleprim_lit = HsLitOut (HsDoublePrim f) doublePrimTy
double_lit = HsCon doubleDataCon [] [doubleprim_lit]
+-- there are no `instances' of functional dependencies
+
+lookupInst (FunDep _ _ _) = returnNF_Tc NoInstance
+
\end{code}
There is a second, simpler interface, when you want an instance of a
tcGetGlobalTyVars, tcExtendGlobalTyVars
)
import TcSimplify ( tcSimplify, tcSimplifyAndCheck, tcSimplifyToDicts )
+import TcImprove ( tcImprove )
import TcMonoType ( tcHsType, checkSigTyVars,
TcSigInfo(..), tcTySig, maybeSig, sigCtxt
)
-- (must do this before getTyVarsToGen)
checkSigMatch top_lvl binder_names mono_ids tc_ty_sigs `thenTc` \ maybe_sig_theta ->
+ -- IMPROVE the LIE
+ -- Force any unifications dictated by functional dependencies.
+ -- Because unification may happen, it's important that this step
+ -- come before:
+ -- - computing vars over which to quantify
+ -- - zonking the generalized type vars
+ tcImprove lie_req `thenTc_`
+
-- COMPUTE VARIABLES OVER WHICH TO QUANTIFY, namely tyvars_to_gen
-- The tyvars_not_to_gen are free in the environment, and hence
-- candidates for generalisation, but sometimes the monomorphism
--- /dev/null
+\begin{code}
+module TcImprove ( tcImprove ) where
+
+#include "HsVersions.h"
+
+import Type ( tyVarsOfTypes )
+import Class ( classInstEnv, classExtraBigSig )
+import Unify ( matchTys )
+import Subst ( mkSubst, substTy )
+import TcMonad
+import TcType ( zonkTcType, zonkTcTypes )
+import TcUnify ( unifyTauTyLists )
+import Inst ( Inst, LookupInstResult(..),
+ lookupInst, isDict, getDictClassTys, getFunDeps,
+ zonkLIE {- for debugging -} )
+import VarSet ( emptyVarSet )
+import VarEnv ( emptyVarEnv )
+import FunDeps ( instantiateFdClassTys )
+import Bag ( bagToList )
+import Outputable
+import List ( elemIndex )
+import Maybe ( catMaybes )
+\end{code}
+
+Improvement goes here.
+
+\begin{code}
+tcImprove lie
+ = let cfdss = catMaybes (map getFunDeps (bagToList lie)) in
+ iterImprove cfdss
+
+iterImprove cfdss
+ = instImprove cfdss `thenTc` \ change1 ->
+ selfImprove pairImprove cfdss `thenTc` \ change2 ->
+ if change1 || change2 then
+ iterImprove cfdss
+ else
+ returnTc ()
+
+instImprove (cfds@(clas, fds) : cfdss)
+ = instImprove1 cfds ins
+ where ins = classInstEnv clas
+instImprove [] = returnTc False
+
+instImprove1 cfds@(clas, fds1) ((free, ts, _) : ins)
+ = checkFds fds1 free fds2 `thenTc` \ changed ->
+ instImprove1 cfds ins `thenTc` \ rest_changed ->
+ returnTc (changed || rest_changed)
+ where fds2 = instantiateFdClassTys clas ts
+instImprove1 _ _ = returnTc False
+
+selfImprove f [] = returnTc False
+selfImprove f (cfds : cfdss)
+ = mapTc (f cfds) cfdss `thenTc` \ changes ->
+ orTc changes `thenTc` \ changed ->
+ selfImprove f cfdss `thenTc` \ rest_changed ->
+ returnTc (changed || rest_changed)
+
+pairImprove (clas1, fds1) (clas2, fds2)
+ = if clas1 == clas2 then
+ checkFds fds1 emptyVarSet fds2
+ else
+ returnTc False
+
+checkFds [] free [] = returnTc False
+checkFds (fd1 : fd1s) free (fd2 : fd2s) =
+ checkFd fd1 free fd2 `thenTc` \ change ->
+ checkFds fd1s free fd2s `thenTc` \ changes ->
+ returnTc (change || changes)
+--checkFds _ _ = returnTc False
+
+checkFd (t_x, t_y) free (s_x, s_y)
+ -- we need to zonk each time because unification
+ -- may happen at any time
+ = zonkMatchTys t_x free s_x `thenTc` \ msubst ->
+ case msubst of
+ Just subst ->
+ let s_y' = map (substTy (mkSubst emptyVarEnv subst)) s_y in
+ zonkMatchTys t_y free s_y `thenTc` \ msubst2 ->
+ case msubst2 of
+ Just _ ->
+ -- they're the same, nothing changes
+ returnTc False
+ Nothing ->
+ unifyTauTyLists t_y s_y' `thenTc_`
+ -- if we get here, something must have unified
+ returnTc True
+ Nothing ->
+ returnTc False
+
+zonkMatchTys ts1 free ts2
+ = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
+ mapTc zonkTcType ts2 `thenTc` \ ts2' ->
+ --returnTc (ts1' == ts2')
+ case matchTys free ts2' ts1' of
+ Just (subst, []) -> returnTc (Just subst)
+ Nothing -> returnTc Nothing
+
+{-
+instImprove clas fds =
+ pprTrace "class inst env" (ppr (clas, classInstEnv clas)) $
+ zonkFunDeps fds `thenTc` \ fds' ->
+ pprTrace "lIEFDs" (ppr (clas, fds')) $
+ case lookupInstEnvFDs clas fds' of
+ Nothing -> returnTc ()
+ Just (t_y, s_y) ->
+ pprTrace "lIEFDs result" (ppr (t_y, s_y)) $
+ unifyTauTyLists t_y s_y
+
+lookupInstEnvFDs clas fds
+ = find env
+ where
+ env = classInstEnv clas
+ (ctvs, fds, _, _, _, _) = classExtraBigSig clas
+ find [] = Nothing
+ find ((tpl_tyvars, tpl, val) : rest)
+ = let tplx = concatMap (\us -> thingy tpl us ctvs) (map fst fds)
+ tply = concatMap (\vs -> thingy tpl vs ctvs) (map snd fds)
+ in
+ case matchTys tpl_tyvars tplx tysx of
+ Nothing -> find rest
+ Just (tenv, leftovers) ->
+ let subst = mkSubst (tyVarsOfTypes tys) tenv
+ in
+ -- this is the list of things that
+ -- need to be unified
+ Just (map (substTy subst) tply, tysy)
+ tysx = concatMap (\us -> thingy tys us ctvs) (map fst fds)
+ tysy = concatMap (\vs -> thingy tys vs ctvs) (map snd fds)
+ thingy f us ctvs
+ = map (f !!) is
+ where is = map (\u -> let Just i = elemIndex u ctvs in i) us
+-}
+
+{-
+ = let (clas, tys) = getDictClassTys dict
+ in
+ -- first, do instance-based improvement
+ instImprove clas tys `thenTc_`
+ -- OK, now do pairwise stuff
+ mapTc (f clas tys) dicts `thenTc` \ changes ->
+ foldrTc (\a b -> returnTc (a || b)) False changes `thenTc` \ changed ->
+ allDictPairs f dicts `thenTc` \ rest_changed ->
+ returnTc (changed || rest_changed)
+-}
+
+\end{code}
+
+Utilities:
+
+A monadic version of the standard Prelude `or' function.
+\begin{code}
+orTc bs = foldrTc (\a b -> returnTc (a || b)) False bs
+\end{code}
import TcMonad
import Inst ( lookupInst, lookupSimpleInst, LookupInstResult(..),
tyVarsOfInst,
- isDict, isStdClassTyVarDict, isMethodFor,
+ isDict, isStdClassTyVarDict, isMethodFor, notFunDep,
instToId, instBindingRequired, instCanBeGeneralised,
newDictFromOld,
getDictClassTys,
-- Finished
returnTc (mkLIE frees, binds, mkLIE irreds')
where
- wanteds = bagToList wanted_lie
+ -- the idea behind filtering out the dependencies here is that
+ -- they've already served their purpose, and can be reconstructed
+ -- at a later point from the retained class predicates.
+ -- however, there *is* the possibility that a dependency
+ -- out-lives the predicate from which it arose.
+ -- I don't have any examples of this, but if they show up,
+ -- we'd want to consider the possibility of saving the
+ -- dependencies as hidden constraints (i.e. they'd only
+ -- show up in interface files) -- or maybe they'd be useful
+ -- as first class predicates...
+ wanteds = filter notFunDep (bagToList wanted_lie)
try_me inst
-- Does not constrain a local tyvar
returnTc (mkLIE frees, binds)
where
givens = bagToList given_lie
- wanteds = bagToList wanted_lie
+ -- see comment on wanteds in tcSimplify
+ wanteds = filter notFunDep (bagToList wanted_lie)
given_dicts = filter isDict givens
try_me inst
returnTc (binds1 `andMonoBinds` andMonoBindList binds_ambig)
where
- wanteds = bagToList wanted_lie
+ -- see comment on wanteds in tcSimplify
+ wanteds = filter notFunDep (bagToList wanted_lie)
try_me inst = ReduceMe AddToIrreds
d1 `cmp_by_tyvar` d2 = get_tv d1 `compare` get_tv d2
It's better to read it as: "if we know these, then we're going to know these"
\begin{code}
-module FunDeps(oclose, instantiateFundeps, instantiateFdTys, instantiateFdClassTys, pprFundeps) where
+module FunDeps(oclose, instantiateFdClassTys, pprFundeps) where
#include "HsVersions.h"
-import Inst (getDictClassTys)
import Class (classTvsFds)
-import Type (getTyVar_maybe, tyVarsOfType)
import Outputable (interppSP, ptext, empty, hsep, punctuate, comma)
-import UniqSet (elementOfUniqSet, addOneToUniqSet,
- uniqSetToList, unionManyUniqSets)
+import UniqSet (elementOfUniqSet, addOneToUniqSet )
import List (elemIndex)
-import Maybe (catMaybes)
-import FastString
oclose fds vs =
case oclose1 fds vs of
where
(ys', b) = ounion xs ys
--- instantiate fundeps to type variables
-instantiateFundeps dict =
- map (\(xs, ys) -> (unionMap getTyVars xs, unionMap getTyVars ys)) fdtys
- where
- fdtys = instantiateFdTys dict
- getTyVars ty = tyVarsOfType ty
- unionMap f xs = uniqSetToList (unionManyUniqSets (map f xs))
-
--- instantiate fundeps to types
-instantiateFdTys dict = instantiateFdClassTys clas ts
- where (clas, ts) = getDictClassTys dict
instantiateFdClassTys clas ts =
map (lookupInstFundep tyvars ts) fundeps
where