Skip to content

Commit 5c8dde2

Browse files
feat: add chat
1 parent 0ea7ed3 commit 5c8dde2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2735
-292
lines changed

go.mod

+9-8
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ require (
88
github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d
99
github.com/acorn-io/cmd v0.0.0-20240404013709-34f690bde37b
1010
github.com/adrg/xdg v0.4.0
11+
github.com/chzyer/readline v1.5.1
1112
github.com/docker/cli v26.0.0+incompatible
1213
github.com/docker/docker-credential-helpers v0.8.1
1314
github.com/fatih/color v1.16.0
1415
github.com/getkin/kin-openapi v0.123.0
1516
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1617
github.com/gptscript-ai/chat-completion-client v0.0.0-20240404013040-49eb8f6affa1
17-
github.com/hexops/autogold/v2 v2.1.0
18+
github.com/hexops/autogold/v2 v2.2.1
1819
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
1920
github.com/mholt/archiver/v4 v4.0.0-alpha.8
2021
github.com/olahol/melody v1.1.4
@@ -26,8 +27,8 @@ require (
2627
github.com/stretchr/testify v1.8.4
2728
github.com/tidwall/gjson v1.17.1
2829
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
29-
golang.org/x/sync v0.6.0
30-
golang.org/x/term v0.16.0
30+
golang.org/x/sync v0.7.0
31+
golang.org/x/term v0.19.0
3132
gopkg.in/yaml.v3 v3.0.1
3233
)
3334

@@ -48,7 +49,7 @@ require (
4849
github.com/hashicorp/go-multierror v1.1.1 // indirect
4950
github.com/hexops/autogold v1.3.1 // indirect
5051
github.com/hexops/gotextdiff v1.0.3 // indirect
51-
github.com/hexops/valast v1.4.3 // indirect
52+
github.com/hexops/valast v1.4.4 // indirect
5253
github.com/inconshreveable/mousetrap v1.1.0 // indirect
5354
github.com/invopop/yaml v0.2.0 // indirect
5455
github.com/josharian/intern v1.0.0 // indirect
@@ -75,11 +76,11 @@ require (
7576
github.com/tidwall/pretty v1.2.0 // indirect
7677
github.com/ulikunitz/xz v0.5.10 // indirect
7778
go4.org v0.0.0-20200411211856-f5505b9728dd // indirect
78-
golang.org/x/mod v0.15.0 // indirect
79-
golang.org/x/net v0.20.0 // indirect
80-
golang.org/x/sys v0.16.0 // indirect
79+
golang.org/x/mod v0.17.0 // indirect
80+
golang.org/x/net v0.24.0 // indirect
81+
golang.org/x/sys v0.19.0 // indirect
8182
golang.org/x/text v0.14.0 // indirect
82-
golang.org/x/tools v0.17.0 // indirect
83+
golang.org/x/tools v0.20.0 // indirect
8384
gotest.tools/v3 v3.5.1 // indirect
8485
mvdan.cc/gofumpt v0.6.0 // indirect
8586
)

go.sum

+52-15
Large diffs are not rendered by default.

pkg/builtin/builtin.go

+29
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ var tools = map[string]types.Tool{
8585
},
8686
BuiltinFunc: SysAbort,
8787
},
88+
"sys.chat.finish": {
89+
Parameters: types.Parameters{
90+
Description: "Finishes a chat dialog. After this call the conversation will conclude.",
91+
Arguments: types.ObjectSchema(
92+
"message", "The final message to return to the user",
93+
),
94+
},
95+
BuiltinFunc: SysChatFinish,
96+
},
8897
"sys.http.post": {
8998
Parameters: types.Parameters{
9099
Description: "Write contents to a http or https URL using the POST method",
@@ -524,6 +533,26 @@ func SysGetenv(ctx context.Context, env []string, input string) (string, error)
524533
return os.Getenv(params.Name), nil
525534
}
526535

536+
type ErrChatFinish struct {
537+
Message string
538+
}
539+
540+
func (e *ErrChatFinish) Error() string {
541+
return fmt.Sprintf("CHAT FINISH: %s", e.Message)
542+
}
543+
544+
func SysChatFinish(ctx context.Context, env []string, input string) (string, error) {
545+
var params struct {
546+
Message string `json:"message,omitempty"`
547+
}
548+
if err := json.Unmarshal([]byte(input), &params); err != nil {
549+
return "", err
550+
}
551+
return "", &ErrChatFinish{
552+
Message: params.Message,
553+
}
554+
}
555+
527556
func SysAbort(ctx context.Context, env []string, input string) (string, error) {
528557
var params struct {
529558
Message string `json:"message,omitempty"`

pkg/chat/chat.go

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package chat
2+
3+
import (
4+
"context"
5+
6+
"github.com/fatih/color"
7+
"github.com/gptscript-ai/gptscript/pkg/runner"
8+
"github.com/gptscript-ai/gptscript/pkg/types"
9+
)
10+
11+
type Prompter interface {
12+
Readline() (string, bool, error)
13+
Printf(format string, args ...interface{}) (int, error)
14+
SetPrompt(p string)
15+
Close() error
16+
}
17+
18+
type Chatter interface {
19+
Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (resp runner.ChatResponse, err error)
20+
}
21+
22+
type GetProgram func() (types.Program, error)
23+
24+
func getPrompt(prg types.Program, resp runner.ChatResponse) string {
25+
name := prg.ChatName()
26+
if newName := prg.ToolSet[resp.ToolID].Name; newName != "" {
27+
name = newName
28+
}
29+
30+
return color.GreenString("%s> ", name)
31+
}
32+
33+
func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg GetProgram, env []string, startInput string) error {
34+
var (
35+
prompter Prompter
36+
)
37+
38+
prompter, err := newReadlinePrompter()
39+
if err != nil {
40+
return err
41+
}
42+
defer prompter.Close()
43+
44+
for {
45+
var (
46+
input string
47+
ok bool
48+
resp runner.ChatResponse
49+
)
50+
51+
prg, err := prg()
52+
if err != nil {
53+
return err
54+
}
55+
56+
prompter.SetPrompt(getPrompt(prg, resp))
57+
58+
if startInput != "" {
59+
input = startInput
60+
startInput = ""
61+
} else {
62+
input, ok, err = prompter.Readline()
63+
if !ok || err != nil {
64+
return err
65+
}
66+
}
67+
68+
resp, err = chatter.Chat(ctx, prevState, prg, env, input)
69+
if err != nil || resp.Done {
70+
return err
71+
}
72+
73+
if resp.Content != "" {
74+
_, err := prompter.Printf(color.RedString("< %s\n", resp.Content))
75+
if err != nil {
76+
return err
77+
}
78+
}
79+
80+
prevState = resp.State
81+
}
82+
}

pkg/chat/readline.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package chat
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"strings"
8+
9+
"github.com/chzyer/readline"
10+
"github.com/fatih/color"
11+
"github.com/gptscript-ai/gptscript/pkg/mvl"
12+
)
13+
14+
var _ Prompter = (*readlinePrompter)(nil)
15+
16+
type readlinePrompter struct {
17+
readliner *readline.Instance
18+
}
19+
20+
func newReadlinePrompter() (*readlinePrompter, error) {
21+
l, err := readline.NewEx(&readline.Config{
22+
Prompt: color.GreenString("> "),
23+
InterruptPrompt: "^C",
24+
EOFPrompt: "exit",
25+
HistorySearchFold: true,
26+
})
27+
if err != nil {
28+
return nil, err
29+
}
30+
31+
l.CaptureExitSignal()
32+
mvl.SetOutput(l.Stderr())
33+
34+
return &readlinePrompter{
35+
readliner: l,
36+
}, nil
37+
}
38+
39+
func (r *readlinePrompter) Printf(format string, args ...interface{}) (int, error) {
40+
return fmt.Fprintf(r.readliner.Stdout(), format, args...)
41+
}
42+
43+
func (r *readlinePrompter) Readline() (string, bool, error) {
44+
line, err := r.readliner.Readline()
45+
if errors.Is(err, readline.ErrInterrupt) {
46+
return "", false, nil
47+
} else if errors.Is(err, io.EOF) {
48+
return "", false, nil
49+
}
50+
return strings.TrimSpace(line), true, nil
51+
}
52+
53+
func (r *readlinePrompter) SetPrompt(prompt string) {
54+
r.readliner.SetPrompt(prompt)
55+
}
56+
57+
func (r *readlinePrompter) Close() error {
58+
return r.readliner.Close()
59+
}

pkg/cli/gptscript.go

+36-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cli
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"io"
78
"os"
@@ -14,6 +15,7 @@ import (
1415
"github.com/gptscript-ai/gptscript/pkg/assemble"
1516
"github.com/gptscript-ai/gptscript/pkg/builtin"
1617
"github.com/gptscript-ai/gptscript/pkg/cache"
18+
"github.com/gptscript-ai/gptscript/pkg/chat"
1719
"github.com/gptscript-ai/gptscript/pkg/confirm"
1820
"github.com/gptscript-ai/gptscript/pkg/gptscript"
1921
"github.com/gptscript-ai/gptscript/pkg/input"
@@ -57,6 +59,10 @@ type GPTScript struct {
5759
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
5860
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
5961
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
62+
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"`
63+
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"`
64+
65+
readData []byte
6066
}
6167

6268
func New() *cobra.Command {
@@ -245,9 +251,18 @@ func (r *GPTScript) readProgram(ctx context.Context, args []string) (prg types.P
245251
}
246252

247253
if args[0] == "-" {
248-
data, err := io.ReadAll(os.Stdin)
249-
if err != nil {
250-
return prg, err
254+
var (
255+
data []byte
256+
err error
257+
)
258+
if len(r.readData) > 0 {
259+
data = r.readData
260+
} else {
261+
data, err = io.ReadAll(os.Stdin)
262+
if err != nil {
263+
return prg, err
264+
}
265+
r.readData = data
251266
}
252267
return loader.ProgramFromSource(ctx, string(data), r.SubTool)
253268
}
@@ -349,6 +364,24 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
349364
return err
350365
}
351366

367+
if r.ChatState != "" {
368+
resp, err := gptScript.Chat(r.NewRunContext(cmd), r.ChatState, prg, os.Environ(), toolInput)
369+
if err != nil {
370+
return err
371+
}
372+
data, err := json.Marshal(resp)
373+
if err != nil {
374+
return err
375+
}
376+
return r.PrintOutput(toolInput, string(data))
377+
}
378+
379+
if prg.IsChat() || r.ForceChat {
380+
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
381+
return r.readProgram(ctx, args)
382+
}, os.Environ(), toolInput)
383+
}
384+
352385
s, err := gptScript.Run(r.NewRunContext(cmd), prg, os.Environ(), toolInput)
353386
if err != nil {
354387
return err

pkg/engine/cmd.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121
"github.com/gptscript-ai/gptscript/pkg/version"
2222
)
2323

24-
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, isCredential bool) (cmdOut string, cmdErr error) {
24+
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) {
2525
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
2626

2727
defer func() {
@@ -65,7 +65,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string,
6565
cmd.Stderr = io.MultiWriter(all, os.Stderr)
6666
cmd.Stdout = io.MultiWriter(all, output)
6767

68-
if isCredential {
68+
if toolCategory == CredentialToolCategory {
6969
pause := context2.GetPauseFuncFromCtx(ctx)
7070
unpause := pause()
7171
defer unpause()

0 commit comments

Comments
 (0)