Skip to content

Commit 33741b1

Browse files
chore: refactor logic for tool sharing
1 parent fbb8f5d commit 33741b1

File tree

17 files changed

+211
-313
lines changed

17 files changed

+211
-313
lines changed

Diff for: pkg/engine/engine.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) (Context,
204204
Input: input,
205205
}
206206

207-
agentGroup, err := callCtx.Tool.GetAgents(*prg)
207+
agentGroup, err := callCtx.Tool.GetToolsByType(prg, types.ToolTypeAgent)
208208
if err != nil {
209209
return callCtx, err
210210
}
@@ -225,7 +225,7 @@ func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID stri
225225
callID = counter.Next()
226226
}
227227

228-
agentGroup, err := c.Tool.GetNextAgentGroup(*c.Program, c.AgentGroup, toolID)
228+
agentGroup, err := c.Tool.GetNextAgentGroup(c.Program, c.AgentGroup, toolID)
229229
if err != nil {
230230
return Context{}, err
231231
}
@@ -272,7 +272,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too
272272
}
273273

274274
var err error
275-
completion.Tools, err = tool.GetCompletionTools(*ctx.Program, ctx.AgentGroup...)
275+
completion.Tools, err = tool.GetChatCompletionTools(*ctx.Program, ctx.AgentGroup...)
276276
if err != nil {
277277
return err
278278
}

Diff for: pkg/runner/input.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ import (
55
"fmt"
66

77
"github.com/gptscript-ai/gptscript/pkg/engine"
8+
"github.com/gptscript-ai/gptscript/pkg/types"
89
)
910

1011
func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
11-
inputToolRefs, err := callCtx.Tool.GetInputFilterTools(*callCtx.Program)
12+
inputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeInput)
1213
if err != nil {
1314
return "", err
1415
}

Diff for: pkg/runner/output.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ import (
66
"fmt"
77

88
"github.com/gptscript-ai/gptscript/pkg/engine"
9+
"github.com/gptscript-ai/gptscript/pkg/types"
910
)
1011

1112
func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) {
12-
outputToolRefs, err := callCtx.Tool.GetOutputFilterTools(*callCtx.Program)
13+
outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput)
1314
if err != nil {
1415
return nil, err
1516
}

Diff for: pkg/runner/runner.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
330330
}
331331

332332
func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
333-
toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program)
333+
toolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeContext)
334334
if err != nil {
335335
return nil, err
336336
}
@@ -387,7 +387,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
387387
return nil, err
388388
}
389389

390-
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
390+
credTools, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeCredential)
391391
if err != nil {
392392
return nil, err
393393
}
@@ -503,7 +503,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
503503
progress, progressClose := streamProgress(&callCtx, monitor)
504504
defer progressClose()
505505

506-
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
506+
credTools, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeCredential)
507507
if err != nil {
508508
return nil, err
509509
}

Diff for: pkg/tests/testdata/TestAgentOnly/call2.golden

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
"tools": [
55
{
66
"function": {
7-
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
8-
"name": "agent1",
7+
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
8+
"name": "agent3",
99
"parameters": {
1010
"properties": {
1111
"defaultPromptParameter": {
@@ -19,8 +19,8 @@
1919
},
2020
{
2121
"function": {
22-
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
23-
"name": "agent3",
22+
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
23+
"name": "agent1",
2424
"parameters": {
2525
"properties": {
2626
"defaultPromptParameter": {

Diff for: pkg/tests/testdata/TestAgentOnly/step1.golden

+4-4
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@
9696
"tools": [
9797
{
9898
"function": {
99-
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
100-
"name": "agent1",
99+
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
100+
"name": "agent3",
101101
"parameters": {
102102
"properties": {
103103
"defaultPromptParameter": {
@@ -111,8 +111,8 @@
111111
},
112112
{
113113
"function": {
114-
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
115-
"name": "agent3",
114+
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
115+
"name": "agent1",
116116
"parameters": {
117117
"properties": {
118118
"defaultPromptParameter": {

Diff for: pkg/tests/testdata/TestAgents/call3-resp.golden

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"content": [
44
{
55
"toolCall": {
6-
"index": 1,
6+
"index": 0,
77
"id": "call_3",
88
"function": {
99
"name": "agent3"

Diff for: pkg/tests/testdata/TestAgents/call3.golden

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
"tools": [
55
{
66
"function": {
7-
"toolID": "testdata/TestAgents/test.gpt:agent1",
8-
"name": "agent1",
7+
"toolID": "testdata/TestAgents/test.gpt:agent3",
8+
"name": "agent3",
99
"parameters": {
1010
"properties": {
1111
"defaultPromptParameter": {
@@ -19,8 +19,8 @@
1919
},
2020
{
2121
"function": {
22-
"toolID": "testdata/TestAgents/test.gpt:agent3",
23-
"name": "agent3",
22+
"toolID": "testdata/TestAgents/test.gpt:agent1",
23+
"name": "agent1",
2424
"parameters": {
2525
"properties": {
2626
"defaultPromptParameter": {

Diff for: pkg/tests/testdata/TestAgents/step1.golden

+6-6
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@
178178
"tools": [
179179
{
180180
"function": {
181-
"toolID": "testdata/TestAgents/test.gpt:agent1",
182-
"name": "agent1",
181+
"toolID": "testdata/TestAgents/test.gpt:agent3",
182+
"name": "agent3",
183183
"parameters": {
184184
"properties": {
185185
"defaultPromptParameter": {
@@ -193,8 +193,8 @@
193193
},
194194
{
195195
"function": {
196-
"toolID": "testdata/TestAgents/test.gpt:agent3",
197-
"name": "agent3",
196+
"toolID": "testdata/TestAgents/test.gpt:agent1",
197+
"name": "agent1",
198198
"parameters": {
199199
"properties": {
200200
"defaultPromptParameter": {
@@ -222,7 +222,7 @@
222222
"content": [
223223
{
224224
"toolCall": {
225-
"index": 1,
225+
"index": 0,
226226
"id": "call_3",
227227
"function": {
228228
"name": "agent3"
@@ -237,7 +237,7 @@
237237
},
238238
"pending": {
239239
"call_3": {
240-
"index": 1,
240+
"index": 0,
241241
"id": "call_3",
242242
"function": {
243243
"name": "agent3"

Diff for: pkg/tests/testdata/TestExport/call1-resp.golden

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"content": [
44
{
55
"toolCall": {
6-
"index": 2,
6+
"index": 1,
77
"id": "call_1",
88
"function": {
99
"name": "transient"

Diff for: pkg/tests/testdata/TestExport/call1.golden

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
},
1919
{
2020
"function": {
21-
"toolID": "testdata/TestExport/parent.gpt:parent-local",
22-
"name": "parentLocal",
21+
"toolID": "testdata/TestExport/sub/child.gpt:transient",
22+
"name": "transient",
2323
"parameters": {
2424
"properties": {
2525
"defaultPromptParameter": {
@@ -33,8 +33,8 @@
3333
},
3434
{
3535
"function": {
36-
"toolID": "testdata/TestExport/sub/child.gpt:transient",
37-
"name": "transient",
36+
"toolID": "testdata/TestExport/parent.gpt:parent-local",
37+
"name": "parentLocal",
3838
"parameters": {
3939
"properties": {
4040
"defaultPromptParameter": {

Diff for: pkg/tests/testdata/TestExport/call3.golden

+6-6
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
},
1919
{
2020
"function": {
21-
"toolID": "testdata/TestExport/parent.gpt:parent-local",
22-
"name": "parentLocal",
21+
"toolID": "testdata/TestExport/sub/child.gpt:transient",
22+
"name": "transient",
2323
"parameters": {
2424
"properties": {
2525
"defaultPromptParameter": {
@@ -33,8 +33,8 @@
3333
},
3434
{
3535
"function": {
36-
"toolID": "testdata/TestExport/sub/child.gpt:transient",
37-
"name": "transient",
36+
"toolID": "testdata/TestExport/parent.gpt:parent-local",
37+
"name": "parentLocal",
3838
"parameters": {
3939
"properties": {
4040
"defaultPromptParameter": {
@@ -62,7 +62,7 @@
6262
"content": [
6363
{
6464
"toolCall": {
65-
"index": 2,
65+
"index": 1,
6666
"id": "call_1",
6767
"function": {
6868
"name": "transient"
@@ -80,7 +80,7 @@
8080
}
8181
],
8282
"toolCall": {
83-
"index": 2,
83+
"index": 1,
8484
"id": "call_1",
8585
"function": {
8686
"name": "transient"

Diff for: pkg/tests/testdata/TestExportContext/call1.golden

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"role": "system",
3939
"content": [
4040
{
41-
"text": "this is from external context\nthis is from context\nThis is from tool"
41+
"text": "this is from context\nthis is from external context\nThis is from tool"
4242
}
4343
],
4444
"usage": {}

Diff for: pkg/tests/testdata/TestToolRefAll/call1.golden

+9-9
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
},
1919
{
2020
"function": {
21-
"toolID": "testdata/TestToolRefAll/test.gpt:none",
22-
"name": "none",
21+
"toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant",
22+
"name": "agentAssistant",
2323
"parameters": {
2424
"properties": {
25-
"noneArg": {
26-
"description": "stuff",
25+
"defaultPromptParameter": {
26+
"description": "Prompt to send to the tool. This may be an instruction or question.",
2727
"type": "string"
2828
}
2929
},
@@ -33,12 +33,12 @@
3333
},
3434
{
3535
"function": {
36-
"toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant",
37-
"name": "agent",
36+
"toolID": "testdata/TestToolRefAll/test.gpt:none",
37+
"name": "none",
3838
"parameters": {
3939
"properties": {
40-
"defaultPromptParameter": {
41-
"description": "Prompt to send to the tool. This may be an instruction or question.",
40+
"noneArg": {
41+
"description": "stuff",
4242
"type": "string"
4343
}
4444
},
@@ -52,7 +52,7 @@
5252
"role": "system",
5353
"content": [
5454
{
55-
"text": "\nShared context\n\nContext Body\nMain tool"
55+
"text": "\nContext Body\n\nShared context\nMain tool"
5656
}
5757
],
5858
"usage": {}

Diff for: pkg/types/completion.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ import (
99
)
1010

1111
type CompletionRequest struct {
12-
Model string `json:"model,omitempty"`
13-
InternalSystemPrompt *bool `json:"internalSystemPrompt,omitempty"`
14-
Tools []CompletionTool `json:"tools,omitempty"`
15-
Messages []CompletionMessage `json:"messages,omitempty"`
16-
MaxTokens int `json:"maxTokens,omitempty"`
17-
Chat bool `json:"chat,omitempty"`
18-
Temperature *float32 `json:"temperature,omitempty"`
19-
JSONResponse bool `json:"jsonResponse,omitempty"`
20-
Cache *bool `json:"cache,omitempty"`
12+
Model string `json:"model,omitempty"`
13+
InternalSystemPrompt *bool `json:"internalSystemPrompt,omitempty"`
14+
Tools []ChatCompletionTool `json:"tools,omitempty"`
15+
Messages []CompletionMessage `json:"messages,omitempty"`
16+
MaxTokens int `json:"maxTokens,omitempty"`
17+
Chat bool `json:"chat,omitempty"`
18+
Temperature *float32 `json:"temperature,omitempty"`
19+
JSONResponse bool `json:"jsonResponse,omitempty"`
20+
Cache *bool `json:"cache,omitempty"`
2121
}
2222

2323
func (r *CompletionRequest) GetCache() bool {
@@ -27,7 +27,7 @@ func (r *CompletionRequest) GetCache() bool {
2727
return *r.Cache
2828
}
2929

30-
type CompletionTool struct {
30+
type ChatCompletionTool struct {
3131
Function CompletionFunctionDefinition `json:"function,omitempty"`
3232
}
3333

Diff for: pkg/types/set.go

+11
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ func (t *toolRefSet) List() (result []ToolReference, err error) {
1919
return result, t.err
2020
}
2121

22+
func (t *toolRefSet) Contains(value ToolReference) bool {
23+
key := toolRefKey{
24+
name: value.Named,
25+
toolID: value.ToolID,
26+
arg: value.Arg,
27+
}
28+
29+
_, ok := t.set[key]
30+
return ok
31+
}
32+
2233
func (t *toolRefSet) HasTool(toolID string) bool {
2334
for _, ref := range t.set {
2435
if ref.ToolID == toolID {

0 commit comments

Comments
 (0)