Skip to content

Commit 029c3e7

Browse files
authored
optimization parseIP in xff (#1915)
1 parent 9fa3a73 commit 029c3e7

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

plugins/wasm-go/extensions/ip-restriction/main.go

+3-6
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ package main
33
import (
44
"encoding/json"
55
"fmt"
6+
"net"
7+
68
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
79
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
810
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
911
"github.com/tidwall/gjson"
1012
"github.com/zmap/go-iptree/iptree"
11-
"net"
12-
"strings"
1313
)
1414

1515
const (
@@ -101,9 +101,6 @@ func getDownStreamIp(config RestrictionConfig) (net.IP, error) {
101101

102102
if config.IPSourceType == HeaderSourceType {
103103
s, err = proxywasm.GetHttpRequestHeader(config.IPHeaderName)
104-
if err == nil {
105-
s = strings.Split(strings.Trim(s, " "), ",")[0]
106-
}
107104
} else {
108105
var bs []byte
109106
bs, err = proxywasm.GetProperty([]string{"source", "address"})
@@ -112,7 +109,7 @@ func getDownStreamIp(config RestrictionConfig) (net.IP, error) {
112109
if err != nil {
113110
return nil, err
114111
}
115-
ip := parseIP(s)
112+
ip := parseIP(s, config.IPSourceType == HeaderSourceType)
116113
realIP := net.ParseIP(ip)
117114
if realIP == nil {
118115
return nil, fmt.Errorf("invalid ip[%s]", ip)

plugins/wasm-go/extensions/ip-restriction/utils.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package main
22

33
import (
44
"fmt"
5+
"strings"
6+
57
"github.com/tidwall/gjson"
68
"github.com/zmap/go-iptree/iptree"
7-
"strings"
89
)
910

1011
// parseIPNets 解析Ip段配置
@@ -24,7 +25,12 @@ func parseIPNets(array []gjson.Result) (*iptree.IPTree, error) {
2425
}
2526

2627
// parseIP 解析IP
27-
func parseIP(source string) string {
28+
func parseIP(source string, fromHeader bool) string {
29+
30+
if fromHeader {
31+
source = strings.Split(source, ",")[0]
32+
}
33+
source = strings.Trim(source, " ")
2834
if strings.Contains(source, ".") {
2935
// parse ipv4
3036
return strings.Split(source, ":")[0]

plugins/wasm-go/extensions/ip-restriction/utils_test.go

+34-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package main
22

33
import (
4-
"github.com/tidwall/gjson"
54
"testing"
5+
6+
"github.com/tidwall/gjson"
67
)
78

89
func Test_parseIPNets(t *testing.T) {
@@ -52,7 +53,8 @@ func Test_parseIPNets(t *testing.T) {
5253

5354
func Test_parseIP(t *testing.T) {
5455
type args struct {
55-
source string
56+
source string
57+
fromHeader bool
5658
}
5759
tests := []struct {
5860
name string
@@ -64,41 +66,70 @@ func Test_parseIP(t *testing.T) {
6466
name: "case 1",
6567
args: args{
6668
"127.0.0.1",
69+
false,
6770
},
6871
want: "127.0.0.1",
6972
},
7073
{
7174
name: "case 2",
7275
args: args{
7376
"127.0.0.1:12",
77+
false,
7478
},
7579
want: "127.0.0.1",
7680
},
7781
{
7882
name: "case 3",
7983
args: args{
8084
"fe80::14d5:8aff:fed9:2114",
85+
false,
8186
},
8287
want: "fe80::14d5:8aff:fed9:2114",
8388
},
8489
{
8590
name: "case 4",
8691
args: args{
8792
"[fe80::14d5:8aff:fed9:2114]:123",
93+
false,
8894
},
8995
want: "fe80::14d5:8aff:fed9:2114",
9096
},
9197
{
9298
name: "case 5",
9399
args: args{
94100
"127.0.0.1:12,[fe80::14d5:8aff:fed9:2114]:123",
101+
true,
102+
},
103+
want: "127.0.0.1",
104+
},
105+
{
106+
name: "case 6",
107+
args: args{
108+
"127.0.0.1,[fe80::14d5:8aff:fed9:2114]:123",
109+
true,
110+
},
111+
want: "127.0.0.1",
112+
},
113+
{
114+
name: "case 7",
115+
args: args{
116+
"[fe80::14d5:8aff:fed9:2114]:123,127.0.0.1",
117+
true,
118+
},
119+
want: "fe80::14d5:8aff:fed9:2114",
120+
},
121+
{
122+
name: "case 8",
123+
args: args{
124+
"127.0.0.1 , [fe80::14d5:8aff:fed9:2114]:123",
125+
true,
95126
},
96127
want: "127.0.0.1",
97128
},
98129
}
99130
for _, tt := range tests {
100131
t.Run(tt.name, func(t *testing.T) {
101-
if got := parseIP(tt.args.source); got != tt.want {
132+
if got := parseIP(tt.args.source, tt.args.fromHeader); got != tt.want {
102133
t.Errorf("parseIP() = %v, want %v", got, tt.want)
103134
}
104135
})

0 commit comments

Comments
 (0)