25
25
import ai .djl .util .PairList ;
26
26
import ai .djl .util .StringPair ;
27
27
28
+ import com .google .gson .JsonArray ;
28
29
import com .google .gson .JsonElement ;
30
+ import com .google .gson .JsonObject ;
29
31
import com .google .gson .JsonParseException ;
30
- import com .google .gson .reflect .TypeToken ;
31
32
32
- import java .lang . reflect . Type ;
33
+ import java .util . ArrayList ;
33
34
import java .util .List ;
34
35
35
36
/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
36
37
public class CrossEncoderServingTranslator implements NoBatchifyTranslator <Input , Output > {
37
38
38
- private static final Type LIST_TYPE = new TypeToken <List <StringPair >>() {}.getType ();
39
-
40
39
private Translator <StringPair , float []> translator ;
41
40
42
41
/**
@@ -63,31 +62,65 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
63
62
}
64
63
65
64
String contentType = input .getProperty ("Content-Type" , null );
66
- StringPair pair ;
65
+ if (contentType != null ) {
66
+ int pos = contentType .indexOf (';' );
67
+ if (pos > 0 ) {
68
+ contentType = contentType .substring (0 , pos );
69
+ }
70
+ }
71
+ StringPair pair = null ;
67
72
if ("application/json" .equals (contentType )) {
68
73
String json = input .getData ().getAsString ();
69
74
try {
70
75
JsonElement element = JsonUtils .GSON .fromJson (json , JsonElement .class );
71
76
if (element .isJsonArray ()) {
72
77
ctx .setAttachment ("batch" , Boolean .TRUE );
73
- List <StringPair > inputs = JsonUtils .GSON .fromJson (json , LIST_TYPE );
78
+ JsonArray array = element .getAsJsonArray ();
79
+ int size = array .size ();
80
+ List <StringPair > inputs = new ArrayList <>(size );
81
+ for (int i = 0 ; i < size ; ++i ) {
82
+ JsonObject obj = array .get (i ).getAsJsonObject ();
83
+ inputs .add (parseStringPair (obj ));
84
+ }
74
85
return translator .batchProcessInput (ctx , inputs );
75
- }
76
-
77
- pair = JsonUtils .GSON .fromJson (json , StringPair .class );
78
- if (pair .getKey () == null || pair .getValue () == null ) {
79
- throw new TranslateException ("Missing key or value in json." );
86
+ } else if (element .isJsonObject ()) {
87
+ JsonObject obj = element .getAsJsonObject ();
88
+ JsonElement query = obj .get ("query" );
89
+ if (query != null ) {
90
+ String key = query .getAsString ();
91
+ JsonArray texts = obj .get ("texts" ).getAsJsonArray ();
92
+ int size = texts .size ();
93
+ List <StringPair > inputs = new ArrayList <>(size );
94
+ for (int i = 0 ; i < size ; ++i ) {
95
+ String value = texts .get (i ).getAsString ();
96
+ inputs .add (new StringPair (key , value ));
97
+ }
98
+ ctx .setAttachment ("batch" , Boolean .TRUE );
99
+ return translator .batchProcessInput (ctx , inputs );
100
+ } else {
101
+ pair = parseStringPair (obj );
102
+ }
103
+ } else {
104
+ throw new TranslateException ("Unexpected json type" );
80
105
}
81
106
} catch (JsonParseException e ) {
82
107
throw new TranslateException ("Input is not a valid json." , e );
83
108
}
84
109
} else {
110
+ String text = input .getAsString ("text" );
111
+ String textPair = input .getAsString ("text_pair" );
112
+ if (text != null && textPair != null ) {
113
+ pair = new StringPair (text , textPair );
114
+ }
85
115
String key = input .getAsString ("key" );
86
116
String value = input .getAsString ("value" );
87
- if (key == null || value = = null ) {
88
- throw new TranslateException ( "Missing key or value in input." );
117
+ if (key != null && value ! = null ) {
118
+ pair = new StringPair ( key , value );
89
119
}
90
- pair = new StringPair (key , value );
120
+ }
121
+
122
+ if (pair == null ) {
123
+ throw new TranslateException ("Missing key or value in input." );
91
124
}
92
125
93
126
NDList ret = translator .processInput (ctx , pair );
@@ -115,4 +148,18 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
115
148
}
116
149
return output ;
117
150
}
151
+
152
+ private StringPair parseStringPair (JsonObject json ) throws TranslateException {
153
+ JsonElement text = json .get ("text" );
154
+ JsonElement textPair = json .get ("text_pair" );
155
+ if (text != null && textPair != null ) {
156
+ return new StringPair (text .getAsString (), textPair .getAsString ());
157
+ }
158
+ JsonElement key = json .get ("key" );
159
+ JsonElement value = json .get ("value" );
160
+ if (key != null && value != null ) {
161
+ return new StringPair (key .getAsString (), value .getAsString ());
162
+ }
163
+ throw new TranslateException ("Missing text or text_pair in json." );
164
+ }
118
165
}
0 commit comments