@@ -22,8 +22,9 @@ import (
22
22
)
23
23
24
24
type Client struct {
25
- modelsLock sync.Mutex
25
+ clientsLock sync.Mutex
26
26
cache * cache.Client
27
+ clients map [string ]clientInfo
27
28
modelToProvider map [string ]string
28
29
runner * runner.Runner
29
30
envs []string
@@ -38,13 +39,15 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
38
39
envs : envs ,
39
40
credStore : credStore ,
40
41
defaultProvider : defaultProvider ,
42
+ modelToProvider : make (map [string ]string ),
43
+ clients : make (map [string ]clientInfo ),
41
44
}
42
45
}
43
46
44
47
func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
45
- c .modelsLock .Lock ()
48
+ c .clientsLock .Lock ()
46
49
provider , ok := c .modelToProvider [messageRequest .Model ]
47
- c .modelsLock .Unlock ()
50
+ c .clientsLock .Unlock ()
48
51
49
52
if ! ok {
50
53
return nil , fmt .Errorf ("failed to find remote model %s" , messageRequest .Model )
@@ -105,12 +108,8 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
105
108
return false , err
106
109
}
107
110
108
- c .modelsLock .Lock ()
109
- defer c .modelsLock .Unlock ()
110
-
111
- if c .modelToProvider == nil {
112
- c .modelToProvider = map [string ]string {}
113
- }
111
+ c .clientsLock .Lock ()
112
+ defer c .clientsLock .Unlock ()
114
113
115
114
c .modelToProvider [modelString ] = providerName
116
115
return true , nil
@@ -145,11 +144,23 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
145
144
}
146
145
147
146
func (c * Client ) load (ctx context.Context , toolName string ) (* openai.Client , error ) {
147
+ c .clientsLock .Lock ()
148
+ defer c .clientsLock .Unlock ()
149
+
150
+ client , ok := c .clients [toolName ]
151
+ if ok && ! isHTTPURL (toolName ) && engine .IsDaemonRunning (client .url ) {
152
+ return client .client , nil
153
+ }
154
+
148
155
if isHTTPURL (toolName ) {
149
156
remoteClient , err := c .clientFromURL (ctx , toolName )
150
157
if err != nil {
151
158
return nil , err
152
159
}
160
+ c .clients [toolName ] = clientInfo {
161
+ client : remoteClient ,
162
+ url : toolName ,
163
+ }
153
164
return remoteClient , nil
154
165
}
155
166
@@ -165,7 +176,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
165
176
return nil , err
166
177
}
167
178
168
- client , err := openai .NewClient (ctx , c .credStore , openai.Options {
179
+ oClient , err := openai .NewClient (ctx , c .credStore , openai.Options {
169
180
BaseURL : strings .TrimSuffix (url , "/" ) + "/v1" ,
170
181
Cache : c .cache ,
171
182
CacheKey : prg .EntryToolID ,
@@ -174,7 +185,11 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174
185
return nil , err
175
186
}
176
187
177
- return client , nil
188
+ c .clients [toolName ] = clientInfo {
189
+ client : oClient ,
190
+ url : url ,
191
+ }
192
+ return client .client , nil
178
193
}
179
194
180
195
func (c * Client ) retrieveAPIKey (ctx context.Context , env , url string ) (string , error ) {
@@ -185,3 +200,8 @@ func isLocalhost(url string) bool {
185
200
return strings .HasPrefix (url , "http://localhost" ) || strings .HasPrefix (url , "http://127.0.0.1" ) ||
186
201
strings .HasPrefix (url , "https://localhost" ) || strings .HasPrefix (url , "https://127.0.0.1" )
187
202
}
203
+
204
+ type clientInfo struct {
205
+ client * openai.Client
206
+ url string
207
+ }
0 commit comments