Skip to content

Simplify tactics state structure #1449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Feb 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Data.Aeson
import Data.Bifunctor (Bifunctor (bimap))
import Data.Bool (bool)
import Data.Data (Data)
import Data.Foldable (for_)
import Data.Generics.Aliases (mkQ)
import Data.Generics.Schemes (everything)
import Data.Maybe
Expand Down Expand Up @@ -144,6 +145,7 @@ mkWorkspaceEdits
-> RunTacticResults
-> Either ResponseError (Maybe WorkspaceEdit)
mkWorkspaceEdits span dflags ccs uri pm rtr = do
for_ (rtr_other_solns rtr) $ traceMX "other solution"
let g = graftHole (RealSrcSpan span) rtr
response = transform dflags ccs uri g pm
in case response of
Expand Down
75 changes: 31 additions & 44 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ module Ide.Plugin.Tactic.CodeGen
, module Ide.Plugin.Tactic.CodeGen.Utils
) where

import Control.Lens ((%~), (+~), (<>~))
import Control.Lens ((+~))
import Control.Monad.Except
import Control.Monad.State (MonadState)
import Control.Monad.State.Class (modify)
import Data.Generics.Product (field)
import Data.Generics.Product (field)
import Data.List
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Set as S
import Data.Traversable
import DataCon
import Development.IDE.GHC.Compat
Expand All @@ -29,32 +26,16 @@ import Ide.Plugin.Tactic.Judgements
import Ide.Plugin.Tactic.Machinery
import Ide.Plugin.Tactic.Naming
import Ide.Plugin.Tactic.Types
import Type hiding (Var)
import Type hiding (Var)


useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
useOccName jdg name =
-- Only score points if this is in the local hypothesis
case M.lookup name $ hyByName $ jLocalHypothesis jdg of
Just{} -> modify
$ (withUsedVals $ S.insert name)
. (field @"ts_unused_top_vals" %~ S.delete name)
Nothing -> pure ()


------------------------------------------------------------------------------
-- | Doing recursion incurs a small penalty in the score.
countRecursiveCall :: TacticState -> TacticState
countRecursiveCall = field @"ts_recursion_count" +~ 1


------------------------------------------------------------------------------
-- | Insert some values into the unused top values field. These are
-- subsequently removed via 'useOccName'.
addUnusedTopVals :: MonadState TacticState m => S.Set OccName -> m ()
addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals


destructMatches
:: (DataCon -> Judgement -> Rule)
-- ^ How to construct each match
Expand All @@ -63,7 +44,7 @@ destructMatches
-> CType
-- ^ Type being destructed
-> Judgement
-> RuleM (Trace, [RawMatch])
-> RuleM (Synthesized [RawMatch])
destructMatches f scrut t jdg = do
let hy = jEntireHypothesis jdg
g = jGoal jdg
Expand All @@ -76,16 +57,21 @@ destructMatches f scrut t jdg = do
_ -> fmap unzipTrace $ for dcs $ \dc -> do
let args = dataConInstOrigArgTys' dc apps
names <- mkManyGoodNames (hyNamesInScope hy) args
let hy' = zip names $ coerce args
j = introducingPat scrut dc hy'
let hy' = patternHypothesis scrut dc jdg
$ zip names
$ coerce args
j = introduce hy'
$ withNewGoal g jdg
(tr, sg) <- f dc j
modify $ withIntroducedVals $ mappend $ S.fromList names
pure ( rose ("match " <> show dc <> " {" <>
Synthesized tr sc uv sg <- f dc j
pure
$ Synthesized
( rose ("match " <> show dc <> " {" <>
intercalate ", " (fmap show names) <> "}")
$ pure tr
, match [mkDestructPat dc names] $ unLoc sg
)
$ pure tr)
(sc <> hy')
uv
$ match [mkDestructPat dc names]
$ unLoc sg


------------------------------------------------------------------------------
Expand Down Expand Up @@ -114,10 +100,8 @@ infixifyPatIfNecessary dcon x



unzipTrace :: [(Trace, a)] -> (Trace, [a])
unzipTrace l =
let (trs, as) = unzip l
in (rose mempty trs, as)
unzipTrace :: [Synthesized a] -> Synthesized [a]
unzipTrace = sequenceA


-- | Essentially same as 'dataConInstOrigArgTys' in GHC,
Expand Down Expand Up @@ -154,16 +138,19 @@ destruct' :: (DataCon -> Judgement -> Rule) -> HyInfo CType -> Judgement -> Rule
destruct' f hi jdg = do
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
let term = hi_name hi
useOccName jdg term
(tr, ms)
Synthesized tr sc uv ms
<- destructMatches
f
(Just term)
(hi_type hi)
$ disallowing AlreadyDestructed [term] jdg
pure ( rose ("destruct " <> show term) $ pure tr
, noLoc $ case' (var' term) ms
)
pure
$ Synthesized
(rose ("destruct " <> show term) $ pure tr)
sc
(S.insert term uv)
$ noLoc
$ case' (var' term) ms


------------------------------------------------------------------------------
Expand All @@ -186,10 +173,10 @@ buildDataCon
:: Judgement
-> DataCon -- ^ The data con to build
-> [Type] -- ^ Type arguments for the data con
-> RuleM (Trace, LHsExpr GhcPs)
-> RuleM (Synthesized (LHsExpr GhcPs))
buildDataCon jdg dc tyapps = do
let args = dataConInstOrigArgTys' dc tyapps
(tr, sgs)
Synthesized tr sc uv sgs
<- fmap unzipTrace
$ traverse ( \(arg, n) ->
newSubgoal
Expand All @@ -199,6 +186,6 @@ buildDataCon jdg dc tyapps = do
$ CType arg
) $ zip args [0..]
pure
. (rose (show dc) $ pure tr,)
$ Synthesized (rose (show dc) $ pure tr) sc uv
$ mkCon dc sgs

90 changes: 32 additions & 58 deletions plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Judgements.hs
Original file line number Diff line number Diff line change
@@ -1,33 +1,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module Ide.Plugin.Tactic.Judgements
( blacklistingDestruct
, unwhitelistingSplit
, introducingLambda
, introducingRecursively
, introducingPat
, jGoal
, jHypothesis
, jEntireHypothesis
, jPatHypothesis
, substJdg
, unsetIsTopHole
, filterSameTypeFromOtherPositions
, isDestructBlacklisted
, withNewGoal
, jLocalHypothesis
, isSplitWhitelisted
, isPatternMatch
, filterPosition
, isTopHole
, disallowing
, mkFirstJudgement
, hypothesisFromBindings
, isTopLevel
, hyNamesInScope
, hyByName
) where
module Ide.Plugin.Tactic.Judgements where

import Control.Arrow
import Control.Lens hiding (Context)
Expand Down Expand Up @@ -89,35 +63,39 @@ withNewGoal :: a -> Judgement' a -> Judgement' a
withNewGoal t = field @"_jGoal" .~ t


introduce :: Hypothesis a -> Judgement' a -> Judgement' a
introduce hy = field @"_jHypothesis" <>~ hy


------------------------------------------------------------------------------
-- | Helper function for implementing functions which introduce new hypotheses.
introducing
:: (Int -> Provenance) -- ^ A function from the position of the arg to its
-- provenance.
introduceHypothesis
:: (Int -> Int -> Provenance)
-- ^ A function from the total number of args and position of this arg
-- to its provenance.
-> [(OccName, a)]
-> Judgement' a
-> Judgement' a
introducing f ns =
field @"_jHypothesis" <>~ (Hypothesis $ zip [0..] ns <&>
\(pos, (name, ty)) -> HyInfo name (f pos) ty)
-> Hypothesis a
introduceHypothesis f ns =
Hypothesis $ zip [0..] ns <&> \(pos, (name, ty)) ->
HyInfo name (f (length ns) pos) ty


------------------------------------------------------------------------------
-- | Introduce bindings in the context of a lamba.
introducingLambda
lambdaHypothesis
:: Maybe OccName -- ^ The name of the top level function. For any other
-- function, this should be 'Nothing'.
-> [(OccName, a)]
-> Judgement' a
-> Judgement' a
introducingLambda func = introducing $ \pos ->
maybe UserPrv (\x -> TopLevelArgPrv x pos) func
-> Hypothesis a
lambdaHypothesis func =
introduceHypothesis $ \count pos ->
maybe UserPrv (\x -> TopLevelArgPrv x pos count) func


------------------------------------------------------------------------------
-- | Introduce a binding in a recursive context.
introducingRecursively :: [(OccName, a)] -> Judgement' a -> Judgement' a
introducingRecursively = introducing $ const RecursivePrv
recursiveHypothesis :: [(OccName, a)] -> Hypothesis a
recursiveHypothesis = introduceHypothesis $ const $ const RecursivePrv


------------------------------------------------------------------------------
Expand Down Expand Up @@ -176,7 +154,7 @@ findPositionVal jdg defn pos = listToMaybe $ do
-- ancstry through potentially disallowed terms in the hypothesis.
(name, hi) <- M.toList $ M.map (overProvenance expandDisallowed) $ hyByName $ jEntireHypothesis jdg
case hi_provenance hi of
TopLevelArgPrv defn' pos'
TopLevelArgPrv defn' pos' _
| defn == defn'
, pos == pos' -> pure name
PatternMatchPrv pv
Expand Down Expand Up @@ -243,26 +221,22 @@ extremelyStupid__definingFunction =
fst . head . ctxDefiningFuncs


------------------------------------------------------------------------------
-- | Pattern vals are currently tracked in jHypothesis, with an extra piece of
-- data sitting around in jPatternVals.
introducingPat
patternHypothesis
:: Maybe OccName
-> DataCon
-> [(OccName, a)]
-> Judgement' a
-> Judgement' a
introducingPat scrutinee dc ns jdg
= introducing (\pos ->
-> [(OccName, a)]
-> Hypothesis a
patternHypothesis scrutinee dc jdg
= introduceHypothesis $ \_ pos ->
PatternMatchPrv $
PatVal
scrutinee
(maybe mempty
(\scrut -> S.singleton scrut <> getAncestry jdg scrut)
scrutinee)
(Uniquely dc)
pos
) ns jdg
scrutinee
(maybe mempty
(\scrut -> S.singleton scrut <> getAncestry jdg scrut)
scrutinee)
(Uniquely dc)
pos


------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ deriveArbitrary = do
terminal_expr = mkVal "terminal"
oneof_expr = mkVal "oneof"
pure
( tracePrim "deriveArbitrary"
, noLoc $
$ Synthesized (tracePrim "deriveArbitrary")
-- TODO(sandy): This thing is not actually empty! We produced
-- a bespoke binding "terminal", and a not-so-bespoke "n".
-- But maybe it's fine for known rules?
mempty
mempty
$ noLoc $
let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $
case' (infixCall "<=" (mkVal "n") (int 1))
Expand All @@ -57,7 +62,6 @@ deriveArbitrary = do
(list $ fmap genExpr big)
terminal_expr
]
)
_ -> throwError $ GoalMismatch "deriveArbitrary" ty


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ getRhsPosVals rss tcs
, isHole $ occName hole -- and the span is a hole
-> First $ do
patnames <- traverse getPatName ps
pure $ zip patnames $ [0..] <&> TopLevelArgPrv name
pure $ zip patnames $ [0..] <&> \n ->
TopLevelArgPrv name n (length patnames)
_ -> mempty
) tcs

Expand Down
Loading