Skip to content

Make adding missing constraint work in presence of 'forall' (fixes #1164) #1177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

11 changes: 10 additions & 1 deletion ghcide/src/Development/IDE/GHC/Compat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ module Development.IDE.GHC.Compat(
applyPluginsParsedResultAction,
module Compat.HieTypes,
module Compat.HieUtils,

dropForAll
) where

#if MIN_GHC_API_VERSION(8,10,0)
Expand Down Expand Up @@ -283,3 +283,12 @@ pattern ExposePackage s a mr <- DynFlags.ExposePackage s a _ mr
#else
pattern ExposePackage s a mr = DynFlags.ExposePackage s a mr
#endif

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like splitLHsForAllTy no longer exists since ghc 8.10 and there's now just splitLHsForAllTyInvis.
Is this way of handling this situation ok?

-- | Take AST representation of type signature and drop `forall` part from it (if any), returning just type's body
dropForAll :: LHsType pass -> LHsType pass
#if MIN_GHC_API_VERSION(8,10,0)
dropForAll = snd . GHC.splitLHsForAllTyInvis
#else
dropForAll = snd . GHC.splitLHsForAllTy
#endif

7 changes: 4 additions & 3 deletions ghcide/src/Development/IDE/Plugin/CodeAction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -803,12 +803,13 @@ suggestFunctionConstraint ParsedModule{pm_parsed_source = L _ HsModule{hsmodDecl
| L _ (SigD _ (TypeSig _ identifiers (HsWC _ (HsIB _ locatedType)))) <- hsmodDecls
, any (`isSameName` T.unpack typeSignatureName) $ fmap unLoc identifiers
]
srcSpanToRange $ case splitLHsQualTy locatedType of
let typeBody = dropForAll locatedType
srcSpanToRange $ case splitLHsQualTy typeBody of
(L contextSrcSpan _ , _) ->
if isGoodSrcSpan contextSrcSpan
then contextSrcSpan -- The type signature has explicit context
else -- No explicit context, return SrcSpan at the start of type sig where we can write context
let start = srcSpanStart $ getLoc locatedType in mkSrcSpan start start
else -- No explicit context, return SrcSpan at the start of type (after a potential `forall`)
let start = srcSpanStart $ getLoc typeBody in mkSrcSpan start start

isSameName :: IdP GhcPs -> String -> Bool
isSameName x name = showSDocUnsafe (ppr x) == name
Expand Down
41 changes: 39 additions & 2 deletions ghcide/test/exe/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,28 @@ addFunctionConstraintTests = let
, "eq x y = x == y"
]

missingConstraintWithForAllSourceCode :: T.Text -> T.Text
missingConstraintWithForAllSourceCode constraint =
T.unlines
[ "{-# LANGUAGE ExplicitForAll #-}"
, "module Testing where"
, ""
, "eq :: forall a. " <> constraint <> "a -> a -> Bool"
, "eq x y = x == y"
]

incompleteConstraintWithForAllSourceCode :: T.Text -> T.Text
incompleteConstraintWithForAllSourceCode constraint =
T.unlines
[ "{-# LANGUAGE ExplicitForAll #-}"
, "module Testing where"
, ""
, "data Pair a b = Pair a b"
, ""
, "eq :: " <> constraint <> " => Pair a b -> Pair a b -> Bool"
, "eq (Pair x y) (Pair x' y') = x == x' && y == y'"
]

incompleteConstraintSourceCode :: T.Text -> T.Text
incompleteConstraintSourceCode constraint =
T.unlines
Expand Down Expand Up @@ -1978,8 +2000,8 @@ addFunctionConstraintTests = let
, "eq (Pair x y) (Pair x' y') = x == x' && y == y'"
]

check :: T.Text -> T.Text -> T.Text -> TestTree
check actionTitle originalCode expectedCode = testSession (T.unpack actionTitle) $ do
check :: String -> T.Text -> T.Text -> T.Text -> TestTree
check testName actionTitle originalCode expectedCode = testSession testName $ do
doc <- createDoc "Testing.hs" "haskell" originalCode
_ <- waitForDiagnostics
actionsOrCommands <- getCodeActions doc (Range (Position 6 0) (Position 6 maxBound))
Expand All @@ -1990,22 +2012,37 @@ addFunctionConstraintTests = let

in testGroup "add function constraint"
[ check
"no preexisting constraint"
"Add `Eq a` to the context of the type signature for `eq`"
(missingConstraintSourceCode "")
(missingConstraintSourceCode "Eq a => ")
, check
"no preexisting constraint, with forall"
"Add `Eq a` to the context of the type signature for `eq`"
(missingConstraintWithForAllSourceCode "")
(missingConstraintWithForAllSourceCode "Eq a => ")
, check
"preexisting constraint, no parenthesis"
"Add `Eq b` to the context of the type signature for `eq`"
(incompleteConstraintSourceCode "Eq a")
(incompleteConstraintSourceCode "(Eq a, Eq b)")
, check
"preexisting constraints in parenthesis"
"Add `Eq c` to the context of the type signature for `eq`"
(incompleteConstraintSourceCode2 "(Eq a, Eq b)")
(incompleteConstraintSourceCode2 "(Eq a, Eq b, Eq c)")
, check
"preexisting constraints with forall"
"Add `Eq b` to the context of the type signature for `eq`"
(incompleteConstraintWithForAllSourceCode "Eq a")
(incompleteConstraintWithForAllSourceCode "(Eq a, Eq b)")
, check
"preexisting constraint, with extra spaces in context"
"Add `Eq b` to the context of the type signature for `eq`"
(incompleteConstraintSourceCodeWithExtraCharsInContext "( Eq a )")
(incompleteConstraintSourceCodeWithExtraCharsInContext "(Eq a, Eq b)")
, check
"preexisting constraint, with newlines in type signature"
"Add `Eq b` to the context of the type signature for `eq`"
(incompleteConstraintSourceCodeWithNewlinesInTypeSignature "(Eq a)")
(incompleteConstraintSourceCodeWithNewlinesInTypeSignature "(Eq a, Eq b)")
Expand Down