Skip to content

Commit 59252f6

Browse files
johnrengelmanJohn Engelman
authored and
John Engelman
committed
feat: Add support for tools from github enterprise.
1 parent 4fd8e8a commit 59252f6

File tree

3 files changed

+145
-43
lines changed

3 files changed

+145
-43
lines changed

Diff for: pkg/cli/gptscript.go

+25-19
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/gptscript-ai/gptscript/pkg/gptscript"
2424
"github.com/gptscript-ai/gptscript/pkg/input"
2525
"github.com/gptscript-ai/gptscript/pkg/loader"
26+
"github.com/gptscript-ai/gptscript/pkg/loader/github"
2627
"github.com/gptscript-ai/gptscript/pkg/monitor"
2728
"github.com/gptscript-ai/gptscript/pkg/mvl"
2829
"github.com/gptscript-ai/gptscript/pkg/openai"
@@ -54,25 +55,26 @@ type GPTScript struct {
5455
Output string `usage:"Save output to a file, or - for stdout" short:"o"`
5556
EventsStreamTo string `usage:"Stream events to this location, could be a file descriptor/handle (e.g. fd://2), filename, or named pipe (e.g. \\\\.\\pipe\\my-pipe)" name:"events-stream-to"`
5657
// Input should not be using GPTSCRIPT_INPUT env var because that is the same value that is set in tool executions
57-
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
58-
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
59-
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
60-
ListModels bool `usage:"List the models available and exit" local:"true"`
61-
ListTools bool `usage:"List built-in tools and exit" local:"true"`
62-
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
63-
Chdir string `usage:"Change current working directory" short:"C"`
64-
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
65-
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
66-
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
67-
CredentialOverride []string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
68-
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
69-
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
70-
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
71-
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
72-
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
73-
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
74-
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
75-
DefaultModelProvider string `usage:"Default LLM model provider to use, this will override OpenAI settings"`
58+
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
59+
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
60+
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
61+
ListModels bool `usage:"List the models available and exit" local:"true"`
62+
ListTools bool `usage:"List built-in tools and exit" local:"true"`
63+
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
64+
Chdir string `usage:"Change current working directory" short:"C"`
65+
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
66+
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
67+
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
68+
CredentialOverride []string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
69+
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
70+
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
71+
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
72+
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
73+
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
74+
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
75+
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
76+
DefaultModelProvider string `usage:"Default LLM model provider to use, this will override OpenAI settings"`
77+
GithubEnterpriseHostname string `usage:"The host name for a Github Enterprise instance to enable for remote loading" local:"true"`
7678

7779
readData []byte
7880
}
@@ -334,6 +336,10 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
334336
return err
335337
}
336338

339+
if r.GithubEnterpriseHostname != "" {
340+
loader.AddVSC(github.LoaderForPrefix(r.GithubEnterpriseHostname))
341+
}
342+
337343
// If the user is trying to launch the chat-builder UI, then set up the tool and options here.
338344
if r.UI {
339345
if os.Getenv(system.BinEnvVar) == "" {

Diff for: pkg/loader/github/github.go

+63-24
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package github
22

33
import (
44
"context"
5+
"crypto/tls"
56
"encoding/json"
67
"fmt"
78
"io"
@@ -18,52 +19,63 @@ import (
1819
"github.com/gptscript-ai/gptscript/pkg/types"
1920
)
2021

21-
const (
22-
GithubPrefix = "github.com/"
23-
githubRepoURL = "https://github.com/%s/%s.git"
24-
githubDownloadURL = "https://raw.githubusercontent.com/%s/%s/%s/%s"
25-
githubCommitURL = "https://api.github.com/repos/%s/%s/commits/%s"
26-
)
22+
type GithubConfig struct {
23+
Prefix string
24+
RepoURL string
25+
DownloadURL string
26+
CommitURL string
27+
AuthToken string
28+
}
2729

2830
var (
29-
githubAuthToken = os.Getenv("GITHUB_AUTH_TOKEN")
30-
log = mvl.Package()
31+
log = mvl.Package()
32+
defaultGithubConfig = &GithubConfig{
33+
Prefix: "github.com/",
34+
RepoURL: "https://github.com/%s/%s.git",
35+
DownloadURL: "https://raw.githubusercontent.com/%s/%s/%s/%s",
36+
CommitURL: "https://api.github.com/repos/%s/%s/commits/%s",
37+
AuthToken: os.Getenv("GITHUB_AUTH_TOKEN"),
38+
}
3139
)
3240

3341
func init() {
3442
loader.AddVSC(Load)
3543
}
3644

37-
func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string, error) {
38-
url := fmt.Sprintf(githubRepoURL, account, repo)
45+
func getCommitLsRemote(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) {
46+
url := fmt.Sprintf(config.RepoURL, account, repo)
3947
return git.LsRemote(ctx, url, ref)
4048
}
4149

4250
// regexp to match a git commit id
4351
var commitRegexp = regexp.MustCompile("^[a-f0-9]{40}$")
4452

45-
func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
53+
func getCommit(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) {
4654
if commitRegexp.MatchString(ref) {
4755
return ref, nil
4856
}
4957

50-
url := fmt.Sprintf(githubCommitURL, account, repo, ref)
58+
url := fmt.Sprintf(config.CommitURL, account, repo, ref)
5159
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
5260
if err != nil {
5361
return "", fmt.Errorf("failed to create request of %s/%s at %s: %w", account, repo, url, err)
5462
}
5563

56-
if githubAuthToken != "" {
57-
req.Header.Add("Authorization", "Bearer "+githubAuthToken)
64+
if config.AuthToken != "" {
65+
req.Header.Add("Authorization", "Bearer "+config.AuthToken)
5866
}
5967

60-
resp, err := http.DefaultClient.Do(req)
68+
client := http.DefaultClient
69+
if req.Host == config.Prefix && strings.ToLower(os.Getenv("GH_ENTERPRISE_SKIP_VERIFY")) == "true" {
70+
client = &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}
71+
}
72+
resp, err := client.Do(req)
6173
if err != nil {
6274
return "", err
6375
} else if resp.StatusCode != http.StatusOK {
6476
c, _ := io.ReadAll(resp.Body)
6577
resp.Body.Close()
66-
commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref)
78+
commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref, config)
6779
if fallBackErr == nil {
6880
return commit, nil
6981
}
@@ -88,8 +100,28 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
88100
return commit.SHA, nil
89101
}
90102

91-
func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string, *types.Repo, bool, error) {
92-
if !strings.HasPrefix(urlName, GithubPrefix) {
103+
func LoaderForPrefix(prefix string) func(context.Context, *cache.Client, string) (string, string, *types.Repo, bool, error) {
104+
return func(ctx context.Context, c *cache.Client, urlName string) (string, string, *types.Repo, bool, error) {
105+
return LoadWithConfig(ctx, c, urlName, NewGithubEnterpriseConfig(prefix))
106+
}
107+
}
108+
109+
func Load(ctx context.Context, c *cache.Client, urlName string) (string, string, *types.Repo, bool, error) {
110+
return LoadWithConfig(ctx, c, urlName, defaultGithubConfig)
111+
}
112+
113+
func NewGithubEnterpriseConfig(prefix string) *GithubConfig {
114+
return &GithubConfig{
115+
Prefix: prefix,
116+
RepoURL: fmt.Sprintf("https://%s/%%s/%%s.git", prefix),
117+
DownloadURL: fmt.Sprintf("https://raw.%s/%%s/%%s/%%s/%%s", prefix),
118+
CommitURL: fmt.Sprintf("https://%s/api/v3/repos/%%s/%%s/commits/%%s", prefix),
119+
AuthToken: os.Getenv("GH_ENTERPRISE_TOKEN"),
120+
}
121+
}
122+
123+
func LoadWithConfig(ctx context.Context, _ *cache.Client, urlName string, config *GithubConfig) (string, string, *types.Repo, bool, error) {
124+
if !strings.HasPrefix(urlName, config.Prefix) {
93125
return "", "", nil, false, nil
94126
}
95127

@@ -107,12 +139,12 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
107139
account, repo := parts[1], parts[2]
108140
path := strings.Join(parts[3:], "/")
109141

110-
ref, err := getCommit(ctx, account, repo, ref)
142+
ref, err := getCommit(ctx, account, repo, ref, config)
111143
if err != nil {
112144
return "", "", nil, false, err
113145
}
114146

115-
downloadURL := fmt.Sprintf(githubDownloadURL, account, repo, ref, path)
147+
downloadURL := fmt.Sprintf(config.DownloadURL, account, repo, ref, path)
116148
if path == "" || path == "/" || !strings.Contains(parts[len(parts)-1], ".") {
117149
var (
118150
testPath string
@@ -124,13 +156,20 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
124156
} else {
125157
testPath = path + "/" + ext
126158
}
127-
testURL = fmt.Sprintf(githubDownloadURL, account, repo, ref, testPath)
159+
testURL = fmt.Sprintf(config.DownloadURL, account, repo, ref, testPath)
128160
if i == len(types.DefaultFiles)-1 {
129161
// no reason to test the last one, we are just going to use it. Being that the default list is only
130162
// two elements this loop could have been one check, but hey over-engineered code ftw.
131163
break
132164
}
133-
if resp, err := http.Head(testURL); err == nil {
165+
headReq, err := http.NewRequest("HEAD", testURL, nil)
166+
if err != nil {
167+
break
168+
}
169+
if config.AuthToken != "" {
170+
headReq.Header.Add("Authorization", "Bearer "+config.AuthToken)
171+
}
172+
if resp, err := http.DefaultClient.Do(headReq); err == nil {
134173
_ = resp.Body.Close()
135174
if resp.StatusCode == 200 {
136175
break
@@ -141,9 +180,9 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
141180
path = testPath
142181
}
143182

144-
return downloadURL, githubAuthToken, &types.Repo{
183+
return downloadURL, config.AuthToken, &types.Repo{
145184
VCS: "git",
146-
Root: fmt.Sprintf(githubRepoURL, account, repo),
185+
Root: fmt.Sprintf(config.RepoURL, account, repo),
147186
Path: gpath.Dir(path),
148187
Name: gpath.Base(path),
149188
Revision: ref,

Diff for: pkg/loader/github/github_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package github
22

33
import (
44
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"os"
59
"testing"
610

711
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -44,3 +48,56 @@ func TestLoad(t *testing.T) {
4448
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
4549
}).Equal(t, repo)
4650
}
51+
52+
func TestLoad_GithubEnterprise(t *testing.T) {
53+
gheToken := "mytoken"
54+
os.Setenv("GH_ENTERPRISE_SKIP_VERIFY", "true")
55+
os.Setenv("GH_ENTERPRISE_TOKEN", gheToken)
56+
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57+
switch r.URL.Path {
58+
case "/api/v3/repos/gptscript-ai/gptscript/commits/172dfb0":
59+
w.Write([]byte(`{"sha": "172dfb00b48c6adbbaa7e99270933f95887d1b91"}`))
60+
default:
61+
w.WriteHeader(404)
62+
}
63+
}))
64+
defer s.Close()
65+
66+
serverAddr := s.Listener.Addr().String()
67+
68+
url, token, repo, ok, err := LoadWithConfig(context.Background(), nil, fmt.Sprintf("%s/gptscript-ai/gptscript/pkg/loader/testdata/tool@172dfb0", serverAddr), NewGithubEnterpriseConfig(serverAddr))
69+
require.NoError(t, err)
70+
assert.True(t, ok)
71+
autogold.Expect(fmt.Sprintf("https://raw.%s/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/tool/tool.gpt", serverAddr)).Equal(t, url)
72+
autogold.Expect(&types.Repo{
73+
VCS: "git", Root: fmt.Sprintf("https://%s/gptscript-ai/gptscript.git", serverAddr),
74+
Path: "pkg/loader/testdata/tool",
75+
Name: "tool.gpt",
76+
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
77+
}).Equal(t, repo)
78+
autogold.Expect(gheToken).Equal(t, token)
79+
80+
url, token, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/agent@172dfb0")
81+
require.NoError(t, err)
82+
assert.True(t, ok)
83+
autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/agent/agent.gpt").Equal(t, url)
84+
autogold.Expect(&types.Repo{
85+
VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git",
86+
Path: "pkg/loader/testdata/agent",
87+
Name: "agent.gpt",
88+
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
89+
}).Equal(t, repo)
90+
autogold.Expect("").Equal(t, token)
91+
92+
url, token, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/bothtoolagent@172dfb0")
93+
require.NoError(t, err)
94+
assert.True(t, ok)
95+
autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/bothtoolagent/agent.gpt").Equal(t, url)
96+
autogold.Expect(&types.Repo{
97+
VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git",
98+
Path: "pkg/loader/testdata/bothtoolagent",
99+
Name: "agent.gpt",
100+
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
101+
}).Equal(t, repo)
102+
autogold.Expect("").Equal(t, token)
103+
}

0 commit comments

Comments
 (0)