Skip to content

fix: credentials: block output while running credential tools #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/BurntSushi/locker"
"github.com/google/shlex"
"github.com/gptscript-ai/gptscript/pkg/confirm"
"github.com/gptscript-ai/gptscript/pkg/runner"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/jaytaylor/html2text"
)
Expand Down Expand Up @@ -647,15 +646,7 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
return params.Location, nil
}

func SysPrompt(ctx context.Context, _ []string, input string) (_ string, err error) {
monitor := ctx.Value(runner.MonitorKey{})
if monitor == nil {
return "", errors.New("no monitor in context")
}

unpause := monitor.(runner.Monitor).Pause()
defer unpause()

func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error) {
var params struct {
Message string `json:"message,omitempty"`
Fields string `json:"fields,omitempty"`
Expand Down
19 changes: 19 additions & 0 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package context

import (
"context"
)

type pauseKey struct{}

func AddPauseFuncToCtx(ctx context.Context, pauseF func() func()) context.Context {
return context.WithValue(ctx, pauseKey{}, pauseF)
}

func GetPauseFuncFromCtx(ctx context.Context) func() func() {
pauseF, ok := ctx.Value(pauseKey{}).(func() func())
if !ok {
return func() func() { return func() {} }
}
return pauseF
}
9 changes: 8 additions & 1 deletion pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ import (
"sync/atomic"

"github.com/google/shlex"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/env"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)

func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string) (cmdOut string, cmdErr error) {
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, isCredential bool) (cmdOut string, cmdErr error) {
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))

defer func() {
Expand Down Expand Up @@ -64,6 +65,12 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
cmd.Stderr = io.MultiWriter(all, os.Stderr)
cmd.Stdout = io.MultiWriter(all, output)

if isCredential {
pause := context2.GetPauseFuncFromCtx(ctx)
unpause := pause()
defer unpause()
Comment on lines +69 to +71
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could technically be simplified down to defer context2.GetPauseFuncFromCtx(ctx)()() but I was afraid that would be too confusing, so I split it into three lines to make it clearer what is happening here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Part of me would love to see this as defer context2.GetPauseFuncFromCtx(ctx)()(), but I think the split is the right idea.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And then you'd be a terrible person. What maniac would do a chained function call in a single line defer :)

}

if err := cmd.Run(); err != nil {
_, _ = os.Stderr.Write(output.Bytes())
log.Errorf("failed to run tool [%s] cmd %v: %v", tool.Parameters.Name, cmd.Args, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
} else if tool.IsOpenAPI() {
return e.runOpenAPI(tool, input)
}
s, err := e.runCommand(ctx.WrappedContext(), tool, input)
s, err := e.runCommand(ctx.WrappedContext(), tool, input, ctx.IsCredential)
if err != nil {
return nil, err
}
Expand Down
8 changes: 2 additions & 6 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/gptscript-ai/gptscript/pkg/config"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand All @@ -27,8 +28,6 @@ type Monitor interface {
Stop(output string, err error)
}

type MonitorKey struct{}

type Options struct {
MonitorFactory MonitorFactory `usage:"-"`
RuntimeManager engine.RuntimeManager `usage:"-"`
Expand Down Expand Up @@ -178,10 +177,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp
Content: input,
})

// The sys.prompt tool is a special case where we need to pass the monitor to the builtin function.
if callCtx.Tool.BuiltinFunc != nil && callCtx.Tool.ID == "sys.prompt" {
callCtx.Ctx = context.WithValue(callCtx.Ctx, MonitorKey{}, monitor)
}
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)

result, err := e.Start(callCtx, input)
if err != nil {
Expand Down