Skip to content

Commit d952fa5

Browse files
authored
bugfix: plugin will block GET request (#1428)
1 parent e7561c3 commit d952fa5

File tree

1 file changed

+7
-6
lines changed
  • plugins/wasm-go/extensions/ai-security-guard

1 file changed

+7
-6
lines changed

plugins/wasm-go/extensions/ai-security-guard/main.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,6 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
187187
log.Debugf("request checking is disabled")
188188
ctx.DontReadRequestBody()
189189
}
190-
if !config.checkResponse {
191-
log.Debugf("response checking is disabled")
192-
ctx.DontReadResponseBody()
193-
}
194190
return types.ActionContinue
195191
}
196192

@@ -199,7 +195,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
199195
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
200196
model := gjson.GetBytes(body, "model").Raw
201197
ctx.SetContext("requestModel", model)
202-
log.Debugf("Raw response content is: %s", content)
198+
log.Debugf("Raw request content is: %s", content)
203199
if len(content) > 0 {
204200
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
205201
randomID, _ := generateHexID(16)
@@ -321,6 +317,11 @@ func reconvertHeaders(hs map[string][]string) [][2]string {
321317
}
322318

323319
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
320+
if !config.checkResponse {
321+
log.Debugf("response checking is disabled")
322+
ctx.DontReadResponseBody()
323+
return types.ActionContinue
324+
}
324325
headers, err := proxywasm.GetHttpResponseHeaders()
325326
if err != nil {
326327
log.Warnf("failed to get response headers: %v", err)
@@ -399,7 +400,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
399400
var jsonData []byte
400401
if config.protocolOriginal {
401402
jsonData = []byte(denyMessage)
402-
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
403+
} else if isStreamingResponse {
403404
randomID := generateRandomID()
404405
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
405406
} else {

0 commit comments

Comments
 (0)