Skip to content

Commit 463fe29

Browse files
authored
fix: credentials: block output while running credential tools (#261)
Signed-off-by: Grant Linville <[email protected]>
1 parent 62032ec commit 463fe29

File tree

5 files changed

+31
-18
lines changed

5 files changed

+31
-18
lines changed

pkg/builtin/builtin.go

+1-10
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"github.com/BurntSushi/locker"
2222
"github.com/google/shlex"
2323
"github.com/gptscript-ai/gptscript/pkg/confirm"
24-
"github.com/gptscript-ai/gptscript/pkg/runner"
2524
"github.com/gptscript-ai/gptscript/pkg/types"
2625
"github.com/jaytaylor/html2text"
2726
)
@@ -647,15 +646,7 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
647646
return params.Location, nil
648647
}
649648

650-
func SysPrompt(ctx context.Context, _ []string, input string) (_ string, err error) {
651-
monitor := ctx.Value(runner.MonitorKey{})
652-
if monitor == nil {
653-
return "", errors.New("no monitor in context")
654-
}
655-
656-
unpause := monitor.(runner.Monitor).Pause()
657-
defer unpause()
658-
649+
func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error) {
659650
var params struct {
660651
Message string `json:"message,omitempty"`
661652
Fields string `json:"fields,omitempty"`

pkg/context/context.go

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package context
2+
3+
import (
4+
"context"
5+
)
6+
7+
type pauseKey struct{}
8+
9+
func AddPauseFuncToCtx(ctx context.Context, pauseF func() func()) context.Context {
10+
return context.WithValue(ctx, pauseKey{}, pauseF)
11+
}
12+
13+
func GetPauseFuncFromCtx(ctx context.Context) func() func() {
14+
pauseF, ok := ctx.Value(pauseKey{}).(func() func())
15+
if !ok {
16+
return func() func() { return func() {} }
17+
}
18+
return pauseF
19+
}

pkg/engine/cmd.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ import (
1515
"sync/atomic"
1616

1717
"github.com/google/shlex"
18+
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1819
"github.com/gptscript-ai/gptscript/pkg/env"
1920
"github.com/gptscript-ai/gptscript/pkg/types"
2021
"github.com/gptscript-ai/gptscript/pkg/version"
2122
)
2223

23-
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string) (cmdOut string, cmdErr error) {
24+
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, isCredential bool) (cmdOut string, cmdErr error) {
2425
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
2526

2627
defer func() {
@@ -64,6 +65,12 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
6465
cmd.Stderr = io.MultiWriter(all, os.Stderr)
6566
cmd.Stdout = io.MultiWriter(all, output)
6667

68+
if isCredential {
69+
pause := context2.GetPauseFuncFromCtx(ctx)
70+
unpause := pause()
71+
defer unpause()
72+
}
73+
6774
if err := cmd.Run(); err != nil {
6875
_, _ = os.Stderr.Write(output.Bytes())
6976
log.Errorf("failed to run tool [%s] cmd %v: %v", tool.Parameters.Name, cmd.Args, err)

pkg/engine/engine.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
171171
} else if tool.IsOpenAPI() {
172172
return e.runOpenAPI(tool, input)
173173
}
174-
s, err := e.runCommand(ctx.WrappedContext(), tool, input)
174+
s, err := e.runCommand(ctx.WrappedContext(), tool, input, ctx.IsCredential)
175175
if err != nil {
176176
return nil, err
177177
}

pkg/runner/runner.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/gptscript-ai/gptscript/pkg/config"
14+
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1415
"github.com/gptscript-ai/gptscript/pkg/credentials"
1516
"github.com/gptscript-ai/gptscript/pkg/engine"
1617
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -27,8 +28,6 @@ type Monitor interface {
2728
Stop(output string, err error)
2829
}
2930

30-
type MonitorKey struct{}
31-
3231
type Options struct {
3332
MonitorFactory MonitorFactory `usage:"-"`
3433
RuntimeManager engine.RuntimeManager `usage:"-"`
@@ -182,10 +181,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp
182181
Content: input,
183182
})
184183

185-
// The sys.prompt tool is a special case where we need to pass the monitor to the builtin function.
186-
if callCtx.Tool.BuiltinFunc != nil && callCtx.Tool.ID == "sys.prompt" {
187-
callCtx.Ctx = context.WithValue(callCtx.Ctx, MonitorKey{}, monitor)
188-
}
184+
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
189185

190186
result, err := e.Start(callCtx, input)
191187
if err != nil {

0 commit comments

Comments
 (0)