Skip to content

Commit f356013

Browse files
committed
fix: stop running tool for providers that are runnning
A previous change stopped caching clients so that they would be restarted whenever needed. However, the running of the tool produces unwanted output if the provider is already running. This change includes a way to tell whether the provider is running or needs to be restarted. Signed-off-by: Donnie Adams <[email protected]>
1 parent b05ceb7 commit f356013

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

Diff for: pkg/engine/daemon.go

+13-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ import (
1818
var ports Ports
1919

2020
type Ports struct {
21-
daemonPorts map[string]int64
22-
daemonLock sync.Mutex
21+
daemonPorts map[string]int64
22+
daemonsRunning map[string]struct{}
23+
daemonLock sync.Mutex
2324

2425
startPort, endPort int64
2526
usedPorts map[int64]struct{}
@@ -28,6 +29,13 @@ type Ports struct {
2829
daemonWG sync.WaitGroup
2930
}
3031

32+
func IsDaemonRunning(url string) bool {
33+
ports.daemonLock.Lock()
34+
defer ports.daemonLock.Unlock()
35+
_, ok := ports.daemonsRunning[url]
36+
return ok
37+
}
38+
3139
func SetPorts(start, end int64) {
3240
ports.daemonLock.Lock()
3341
defer ports.daemonLock.Unlock()
@@ -164,8 +172,10 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
164172

165173
if ports.daemonPorts == nil {
166174
ports.daemonPorts = map[string]int64{}
175+
ports.daemonsRunning = map[string]struct{}{}
167176
}
168177
ports.daemonPorts[tool.ID] = port
178+
ports.daemonsRunning[url] = struct{}{}
169179

170180
killedCtx, cancel := context.WithCancelCause(ctx)
171181
defer cancel(nil)
@@ -185,6 +195,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
185195
defer ports.daemonLock.Unlock()
186196

187197
delete(ports.daemonPorts, tool.ID)
198+
delete(ports.daemonsRunning, url)
188199
ports.daemonWG.Done()
189200
}()
190201

Diff for: pkg/remote/remote.go

+31-11
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import (
2222
)
2323

2424
type Client struct {
25-
modelsLock sync.Mutex
25+
clientsLock sync.Mutex
2626
cache *cache.Client
27+
clients map[string]clientInfo
2728
modelToProvider map[string]string
2829
runner *runner.Runner
2930
envs []string
@@ -38,13 +39,15 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
3839
envs: envs,
3940
credStore: credStore,
4041
defaultProvider: defaultProvider,
42+
modelToProvider: make(map[string]string),
43+
clients: make(map[string]clientInfo),
4144
}
4245
}
4346

4447
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
45-
c.modelsLock.Lock()
48+
c.clientsLock.Lock()
4649
provider, ok := c.modelToProvider[messageRequest.Model]
47-
c.modelsLock.Unlock()
50+
c.clientsLock.Unlock()
4851

4952
if !ok {
5053
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
@@ -105,12 +108,8 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
105108
return false, err
106109
}
107110

108-
c.modelsLock.Lock()
109-
defer c.modelsLock.Unlock()
110-
111-
if c.modelToProvider == nil {
112-
c.modelToProvider = map[string]string{}
113-
}
111+
c.clientsLock.Lock()
112+
defer c.clientsLock.Unlock()
114113

115114
c.modelToProvider[modelString] = providerName
116115
return true, nil
@@ -145,11 +144,23 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
145144
}
146145

147146
func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
147+
c.clientsLock.Lock()
148+
defer c.clientsLock.Unlock()
149+
150+
client, ok := c.clients[toolName]
151+
if ok && !isHTTPURL(toolName) && engine.IsDaemonRunning(client.url) {
152+
return client.client, nil
153+
}
154+
148155
if isHTTPURL(toolName) {
149156
remoteClient, err := c.clientFromURL(ctx, toolName)
150157
if err != nil {
151158
return nil, err
152159
}
160+
c.clients[toolName] = clientInfo{
161+
client: remoteClient,
162+
url: toolName,
163+
}
153164
return remoteClient, nil
154165
}
155166

@@ -165,7 +176,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
165176
return nil, err
166177
}
167178

168-
client, err := openai.NewClient(ctx, c.credStore, openai.Options{
179+
oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{
169180
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
170181
Cache: c.cache,
171182
CacheKey: prg.EntryToolID,
@@ -174,7 +185,11 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174185
return nil, err
175186
}
176187

177-
return client, nil
188+
c.clients[toolName] = clientInfo{
189+
client: oClient,
190+
url: url,
191+
}
192+
return client.client, nil
178193
}
179194

180195
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
@@ -185,3 +200,8 @@ func isLocalhost(url string) bool {
185200
return strings.HasPrefix(url, "http://localhost") || strings.HasPrefix(url, "http://127.0.0.1") ||
186201
strings.HasPrefix(url, "https://localhost") || strings.HasPrefix(url, "https://127.0.0.1")
187202
}
203+
204+
type clientInfo struct {
205+
client *openai.Client
206+
url string
207+
}

0 commit comments

Comments
 (0)