Skip to content

Commit f92e8dc

Browse files
authored
Merge pull request #784 from thedadams/no-evens-for-running-providers
fix: stop running tool for providers that are runnning
2 parents b05ceb7 + f356013 commit f92e8dc

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

Diff for: pkg/engine/daemon.go

+13-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ import (
1818
var ports Ports
1919

2020
type Ports struct {
21-
daemonPorts map[string]int64
22-
daemonLock sync.Mutex
21+
daemonPorts map[string]int64
22+
daemonsRunning map[string]struct{}
23+
daemonLock sync.Mutex
2324

2425
startPort, endPort int64
2526
usedPorts map[int64]struct{}
@@ -28,6 +29,13 @@ type Ports struct {
2829
daemonWG sync.WaitGroup
2930
}
3031

32+
func IsDaemonRunning(url string) bool {
33+
ports.daemonLock.Lock()
34+
defer ports.daemonLock.Unlock()
35+
_, ok := ports.daemonsRunning[url]
36+
return ok
37+
}
38+
3139
func SetPorts(start, end int64) {
3240
ports.daemonLock.Lock()
3341
defer ports.daemonLock.Unlock()
@@ -164,8 +172,10 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
164172

165173
if ports.daemonPorts == nil {
166174
ports.daemonPorts = map[string]int64{}
175+
ports.daemonsRunning = map[string]struct{}{}
167176
}
168177
ports.daemonPorts[tool.ID] = port
178+
ports.daemonsRunning[url] = struct{}{}
169179

170180
killedCtx, cancel := context.WithCancelCause(ctx)
171181
defer cancel(nil)
@@ -185,6 +195,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
185195
defer ports.daemonLock.Unlock()
186196

187197
delete(ports.daemonPorts, tool.ID)
198+
delete(ports.daemonsRunning, url)
188199
ports.daemonWG.Done()
189200
}()
190201

Diff for: pkg/remote/remote.go

+31-11
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import (
2222
)
2323

2424
type Client struct {
25-
modelsLock sync.Mutex
25+
clientsLock sync.Mutex
2626
cache *cache.Client
27+
clients map[string]clientInfo
2728
modelToProvider map[string]string
2829
runner *runner.Runner
2930
envs []string
@@ -38,13 +39,15 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
3839
envs: envs,
3940
credStore: credStore,
4041
defaultProvider: defaultProvider,
42+
modelToProvider: make(map[string]string),
43+
clients: make(map[string]clientInfo),
4144
}
4245
}
4346

4447
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
45-
c.modelsLock.Lock()
48+
c.clientsLock.Lock()
4649
provider, ok := c.modelToProvider[messageRequest.Model]
47-
c.modelsLock.Unlock()
50+
c.clientsLock.Unlock()
4851

4952
if !ok {
5053
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
@@ -105,12 +108,8 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
105108
return false, err
106109
}
107110

108-
c.modelsLock.Lock()
109-
defer c.modelsLock.Unlock()
110-
111-
if c.modelToProvider == nil {
112-
c.modelToProvider = map[string]string{}
113-
}
111+
c.clientsLock.Lock()
112+
defer c.clientsLock.Unlock()
114113

115114
c.modelToProvider[modelString] = providerName
116115
return true, nil
@@ -145,11 +144,23 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
145144
}
146145

147146
func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
147+
c.clientsLock.Lock()
148+
defer c.clientsLock.Unlock()
149+
150+
client, ok := c.clients[toolName]
151+
if ok && !isHTTPURL(toolName) && engine.IsDaemonRunning(client.url) {
152+
return client.client, nil
153+
}
154+
148155
if isHTTPURL(toolName) {
149156
remoteClient, err := c.clientFromURL(ctx, toolName)
150157
if err != nil {
151158
return nil, err
152159
}
160+
c.clients[toolName] = clientInfo{
161+
client: remoteClient,
162+
url: toolName,
163+
}
153164
return remoteClient, nil
154165
}
155166

@@ -165,7 +176,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
165176
return nil, err
166177
}
167178

168-
client, err := openai.NewClient(ctx, c.credStore, openai.Options{
179+
oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{
169180
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
170181
Cache: c.cache,
171182
CacheKey: prg.EntryToolID,
@@ -174,7 +185,11 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174185
return nil, err
175186
}
176187

177-
return client, nil
188+
c.clients[toolName] = clientInfo{
189+
client: oClient,
190+
url: url,
191+
}
192+
return client.client, nil
178193
}
179194

180195
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
@@ -185,3 +200,8 @@ func isLocalhost(url string) bool {
185200
return strings.HasPrefix(url, "http://localhost") || strings.HasPrefix(url, "http://127.0.0.1") ||
186201
strings.HasPrefix(url, "https://localhost") || strings.HasPrefix(url, "https://127.0.0.1")
187202
}
203+
204+
type clientInfo struct {
205+
client *openai.Client
206+
url string
207+
}

0 commit comments

Comments
 (0)