Skip to content

Commit 999d0b4

Browse files
Merge pull request #277 from ibuildthecloud/main
feat: add chat
2 parents f488a19 + f3c7f43 commit 999d0b4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2658
-245
lines changed

go.mod

+9-8
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ require (
88
github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d
99
github.com/acorn-io/cmd v0.0.0-20240404013709-34f690bde37b
1010
github.com/adrg/xdg v0.4.0
11+
github.com/chzyer/readline v1.5.1
1112
github.com/docker/cli v26.0.0+incompatible
1213
github.com/docker/docker-credential-helpers v0.8.1
1314
github.com/fatih/color v1.16.0
1415
github.com/getkin/kin-openapi v0.123.0
1516
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1617
github.com/gptscript-ai/chat-completion-client v0.0.0-20240404013040-49eb8f6affa1
17-
github.com/hexops/autogold/v2 v2.1.0
18+
github.com/hexops/autogold/v2 v2.2.1
1819
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
1920
github.com/mholt/archiver/v4 v4.0.0-alpha.8
2021
github.com/olahol/melody v1.1.4
@@ -26,8 +27,8 @@ require (
2627
github.com/stretchr/testify v1.8.4
2728
github.com/tidwall/gjson v1.17.1
2829
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
29-
golang.org/x/sync v0.6.0
30-
golang.org/x/term v0.16.0
30+
golang.org/x/sync v0.7.0
31+
golang.org/x/term v0.19.0
3132
gopkg.in/yaml.v3 v3.0.1
3233
)
3334

@@ -48,7 +49,7 @@ require (
4849
github.com/hashicorp/go-multierror v1.1.1 // indirect
4950
github.com/hexops/autogold v1.3.1 // indirect
5051
github.com/hexops/gotextdiff v1.0.3 // indirect
51-
github.com/hexops/valast v1.4.3 // indirect
52+
github.com/hexops/valast v1.4.4 // indirect
5253
github.com/inconshreveable/mousetrap v1.1.0 // indirect
5354
github.com/invopop/yaml v0.2.0 // indirect
5455
github.com/josharian/intern v1.0.0 // indirect
@@ -75,11 +76,11 @@ require (
7576
github.com/tidwall/pretty v1.2.0 // indirect
7677
github.com/ulikunitz/xz v0.5.10 // indirect
7778
go4.org v0.0.0-20200411211856-f5505b9728dd // indirect
78-
golang.org/x/mod v0.15.0 // indirect
79-
golang.org/x/net v0.20.0 // indirect
80-
golang.org/x/sys v0.16.0 // indirect
79+
golang.org/x/mod v0.17.0 // indirect
80+
golang.org/x/net v0.24.0 // indirect
81+
golang.org/x/sys v0.19.0 // indirect
8182
golang.org/x/text v0.14.0 // indirect
82-
golang.org/x/tools v0.17.0 // indirect
83+
golang.org/x/tools v0.20.0 // indirect
8384
gotest.tools/v3 v3.5.1 // indirect
8485
mvdan.cc/gofumpt v0.6.0 // indirect
8586
)

go.sum

+52-15
Large diffs are not rendered by default.

pkg/builtin/builtin.go

+29
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ var tools = map[string]types.Tool{
8585
},
8686
BuiltinFunc: SysAbort,
8787
},
88+
"sys.chat.finish": {
89+
Parameters: types.Parameters{
90+
Description: "Concludes the conversation. This can not be used to ask a question.",
91+
Arguments: types.ObjectSchema(
92+
"summary", "A summary of the dialog",
93+
),
94+
},
95+
BuiltinFunc: SysChatFinish,
96+
},
8897
"sys.http.post": {
8998
Parameters: types.Parameters{
9099
Description: "Write contents to a http or https URL using the POST method",
@@ -524,6 +533,26 @@ func SysGetenv(ctx context.Context, env []string, input string) (string, error)
524533
return os.Getenv(params.Name), nil
525534
}
526535

536+
type ErrChatFinish struct {
537+
Message string
538+
}
539+
540+
func (e *ErrChatFinish) Error() string {
541+
return fmt.Sprintf("CHAT FINISH: %s", e.Message)
542+
}
543+
544+
func SysChatFinish(ctx context.Context, env []string, input string) (string, error) {
545+
var params struct {
546+
Message string `json:"message,omitempty"`
547+
}
548+
if err := json.Unmarshal([]byte(input), &params); err != nil {
549+
return "", err
550+
}
551+
return "", &ErrChatFinish{
552+
Message: params.Message,
553+
}
554+
}
555+
527556
func SysAbort(ctx context.Context, env []string, input string) (string, error) {
528557
var params struct {
529558
Message string `json:"message,omitempty"`

pkg/chat/chat.go

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package chat
2+
3+
import (
4+
"context"
5+
6+
"github.com/fatih/color"
7+
"github.com/gptscript-ai/gptscript/pkg/runner"
8+
"github.com/gptscript-ai/gptscript/pkg/types"
9+
)
10+
11+
type Prompter interface {
12+
Readline() (string, bool, error)
13+
Printf(format string, args ...interface{}) (int, error)
14+
SetPrompt(p string)
15+
Close() error
16+
}
17+
18+
type Chatter interface {
19+
Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (resp runner.ChatResponse, err error)
20+
}
21+
22+
type GetProgram func() (types.Program, error)
23+
24+
func getPrompt(prg types.Program, resp runner.ChatResponse) string {
25+
name := prg.ChatName()
26+
if newName := prg.ToolSet[resp.ToolID].Name; newName != "" {
27+
name = newName
28+
}
29+
30+
return color.GreenString("%s> ", name)
31+
}
32+
33+
func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg GetProgram, env []string, startInput string) error {
34+
var (
35+
prompter Prompter
36+
)
37+
38+
prompter, err := newReadlinePrompter()
39+
if err != nil {
40+
return err
41+
}
42+
defer prompter.Close()
43+
44+
for {
45+
var (
46+
input string
47+
ok bool
48+
resp runner.ChatResponse
49+
)
50+
51+
prg, err := prg()
52+
if err != nil {
53+
return err
54+
}
55+
56+
prompter.SetPrompt(getPrompt(prg, resp))
57+
58+
if startInput != "" {
59+
input = startInput
60+
startInput = ""
61+
} else if !(prevState == nil && prg.ToolSet[prg.EntryToolID].Arguments == nil) {
62+
// The above logic will skip prompting if this is the first loop and the chat expects no args
63+
input, ok, err = prompter.Readline()
64+
if !ok || err != nil {
65+
return err
66+
}
67+
}
68+
69+
resp, err = chatter.Chat(ctx, prevState, prg, env, input)
70+
if err != nil || resp.Done {
71+
return err
72+
}
73+
74+
if resp.Content != "" {
75+
_, err := prompter.Printf(color.RedString("< %s\n", resp.Content))
76+
if err != nil {
77+
return err
78+
}
79+
}
80+
81+
prevState = resp.State
82+
}
83+
}

pkg/chat/readline.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package chat
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"strings"
8+
9+
"github.com/adrg/xdg"
10+
"github.com/chzyer/readline"
11+
"github.com/fatih/color"
12+
"github.com/gptscript-ai/gptscript/pkg/mvl"
13+
)
14+
15+
var _ Prompter = (*readlinePrompter)(nil)
16+
17+
type readlinePrompter struct {
18+
readliner *readline.Instance
19+
}
20+
21+
func newReadlinePrompter() (*readlinePrompter, error) {
22+
historyFile, err := xdg.CacheFile("gptscript/chat.history")
23+
if err != nil {
24+
historyFile = ""
25+
}
26+
27+
l, err := readline.NewEx(&readline.Config{
28+
Prompt: color.GreenString("> "),
29+
HistoryFile: historyFile,
30+
InterruptPrompt: "^C",
31+
EOFPrompt: "exit",
32+
HistorySearchFold: true,
33+
})
34+
if err != nil {
35+
return nil, err
36+
}
37+
38+
l.CaptureExitSignal()
39+
mvl.SetOutput(l.Stderr())
40+
41+
return &readlinePrompter{
42+
readliner: l,
43+
}, nil
44+
}
45+
46+
func (r *readlinePrompter) Printf(format string, args ...interface{}) (int, error) {
47+
return fmt.Fprintf(r.readliner.Stdout(), format, args...)
48+
}
49+
50+
func (r *readlinePrompter) Readline() (string, bool, error) {
51+
line, err := r.readliner.Readline()
52+
if errors.Is(err, readline.ErrInterrupt) {
53+
return "", false, nil
54+
} else if errors.Is(err, io.EOF) {
55+
return "", false, nil
56+
}
57+
return strings.TrimSpace(line), true, nil
58+
}
59+
60+
func (r *readlinePrompter) SetPrompt(prompt string) {
61+
r.readliner.SetPrompt(prompt)
62+
}
63+
64+
func (r *readlinePrompter) Close() error {
65+
return r.readliner.Close()
66+
}

pkg/cli/eval.go

+9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strconv"
77
"strings"
88

9+
"github.com/gptscript-ai/gptscript/pkg/chat"
910
"github.com/gptscript-ai/gptscript/pkg/gptscript"
1011
"github.com/gptscript-ai/gptscript/pkg/input"
1112
"github.com/gptscript-ai/gptscript/pkg/loader"
@@ -15,6 +16,7 @@ import (
1516

1617
type Eval struct {
1718
Tools []string `usage:"Tools available to call"`
19+
Chat bool `usage:"Enable chat"`
1820
MaxTokens int `usage:"Maximum number of tokens to output"`
1921
Model string `usage:"The model to use"`
2022
JSON bool `usage:"Output JSON"`
@@ -33,6 +35,7 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error {
3335
ModelName: e.Model,
3436
JSONResponse: e.JSON,
3537
InternalPrompt: e.InternalPrompt,
38+
Chat: e.Chat,
3639
},
3740
Instructions: strings.Join(args, " "),
3841
}
@@ -66,6 +69,12 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error {
6669
return err
6770
}
6871

72+
if e.Chat {
73+
return chat.Start(e.gptscript.NewRunContext(cmd), nil, runner, func() (types.Program, error) {
74+
return prg, nil
75+
}, os.Environ(), toolInput)
76+
}
77+
6978
toolOutput, err := runner.Run(e.gptscript.NewRunContext(cmd), prg, os.Environ(), toolInput)
7079
if err != nil {
7180
return err

pkg/cli/gptscript.go

+42-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cli
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"io"
78
"os"
@@ -14,6 +15,7 @@ import (
1415
"github.com/gptscript-ai/gptscript/pkg/assemble"
1516
"github.com/gptscript-ai/gptscript/pkg/builtin"
1617
"github.com/gptscript-ai/gptscript/pkg/cache"
18+
"github.com/gptscript-ai/gptscript/pkg/chat"
1719
"github.com/gptscript-ai/gptscript/pkg/confirm"
1820
"github.com/gptscript-ai/gptscript/pkg/gptscript"
1921
"github.com/gptscript-ai/gptscript/pkg/input"
@@ -57,6 +59,10 @@ type GPTScript struct {
5759
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
5860
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
5961
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
62+
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"`
63+
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"`
64+
65+
readData []byte
6066
}
6167

6268
func New() *cobra.Command {
@@ -207,11 +213,17 @@ func (r *GPTScript) PersistentPre(*cobra.Command, []string) error {
207213
r.Quiet = new(bool)
208214
} else {
209215
r.Quiet = &[]bool{true}[0]
216+
if r.Color == nil {
217+
r.Color = new(bool)
218+
}
210219
}
211220
}
212221

213222
if r.Debug {
214223
mvl.SetDebug()
224+
if r.Color == nil {
225+
r.Color = new(bool)
226+
}
215227
} else {
216228
mvl.SetSimpleFormat()
217229
if *r.Quiet {
@@ -245,9 +257,18 @@ func (r *GPTScript) readProgram(ctx context.Context, args []string) (prg types.P
245257
}
246258

247259
if args[0] == "-" {
248-
data, err := io.ReadAll(os.Stdin)
249-
if err != nil {
250-
return prg, err
260+
var (
261+
data []byte
262+
err error
263+
)
264+
if len(r.readData) > 0 {
265+
data = r.readData
266+
} else {
267+
data, err = io.ReadAll(os.Stdin)
268+
if err != nil {
269+
return prg, err
270+
}
271+
r.readData = data
251272
}
252273
return loader.ProgramFromSource(ctx, string(data), r.SubTool)
253274
}
@@ -349,6 +370,24 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
349370
return err
350371
}
351372

373+
if r.ChatState != "" {
374+
resp, err := gptScript.Chat(r.NewRunContext(cmd), r.ChatState, prg, os.Environ(), toolInput)
375+
if err != nil {
376+
return err
377+
}
378+
data, err := json.Marshal(resp)
379+
if err != nil {
380+
return err
381+
}
382+
return r.PrintOutput(toolInput, string(data))
383+
}
384+
385+
if prg.IsChat() || r.ForceChat {
386+
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
387+
return r.readProgram(ctx, args)
388+
}, os.Environ(), toolInput)
389+
}
390+
352391
s, err := gptScript.Run(r.NewRunContext(cmd), prg, os.Environ(), toolInput)
353392
if err != nil {
354393
return err

pkg/engine/cmd.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121
"github.com/gptscript-ai/gptscript/pkg/version"
2222
)
2323

24-
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, isCredential bool) (cmdOut string, cmdErr error) {
24+
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) {
2525
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
2626

2727
defer func() {
@@ -65,7 +65,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string,
6565
cmd.Stderr = io.MultiWriter(all, os.Stderr)
6666
cmd.Stdout = io.MultiWriter(all, output)
6767

68-
if isCredential {
68+
if toolCategory == CredentialToolCategory {
6969
pause := context2.GetPauseFuncFromCtx(ctx)
7070
unpause := pause()
7171
defer unpause()

0 commit comments

Comments
 (0)