Skip to content

Commit 11ff0c1

Browse files
authored
[api] Support TEI input format to reranking model (#3400)
1 parent c4baffb commit 11ff0c1

File tree

6 files changed

+119
-16
lines changed

6 files changed

+119
-16
lines changed

api/src/main/java/ai/djl/modality/nlp/TextPrompt.java

+6
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ public List<String> getBatch() {
7171
*/
7272
public static TextPrompt parseInput(Input input) throws TranslateException {
7373
String contentType = input.getProperty("Content-Type", null);
74+
if (contentType != null) {
75+
int pos = contentType.indexOf(';');
76+
if (pos > 0) {
77+
contentType = contentType.substring(0, pos);
78+
}
79+
}
7480
String text = input.getData().getAsString();
7581
if (!"application/json".equals(contentType)) {
7682
return new TextPrompt(text);

api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java

+61-14
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,17 @@
2525
import ai.djl.util.PairList;
2626
import ai.djl.util.StringPair;
2727

28+
import com.google.gson.JsonArray;
2829
import com.google.gson.JsonElement;
30+
import com.google.gson.JsonObject;
2931
import com.google.gson.JsonParseException;
30-
import com.google.gson.reflect.TypeToken;
3132

32-
import java.lang.reflect.Type;
33+
import java.util.ArrayList;
3334
import java.util.List;
3435

3536
/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
3637
public class CrossEncoderServingTranslator implements NoBatchifyTranslator<Input, Output> {
3738

38-
private static final Type LIST_TYPE = new TypeToken<List<StringPair>>() {}.getType();
39-
4039
private Translator<StringPair, float[]> translator;
4140

4241
/**
@@ -63,31 +62,65 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
6362
}
6463

6564
String contentType = input.getProperty("Content-Type", null);
66-
StringPair pair;
65+
if (contentType != null) {
66+
int pos = contentType.indexOf(';');
67+
if (pos > 0) {
68+
contentType = contentType.substring(0, pos);
69+
}
70+
}
71+
StringPair pair = null;
6772
if ("application/json".equals(contentType)) {
6873
String json = input.getData().getAsString();
6974
try {
7075
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
7176
if (element.isJsonArray()) {
7277
ctx.setAttachment("batch", Boolean.TRUE);
73-
List<StringPair> inputs = JsonUtils.GSON.fromJson(json, LIST_TYPE);
78+
JsonArray array = element.getAsJsonArray();
79+
int size = array.size();
80+
List<StringPair> inputs = new ArrayList<>(size);
81+
for (int i = 0; i < size; ++i) {
82+
JsonObject obj = array.get(i).getAsJsonObject();
83+
inputs.add(parseStringPair(obj));
84+
}
7485
return translator.batchProcessInput(ctx, inputs);
75-
}
76-
77-
pair = JsonUtils.GSON.fromJson(json, StringPair.class);
78-
if (pair.getKey() == null || pair.getValue() == null) {
79-
throw new TranslateException("Missing key or value in json.");
86+
} else if (element.isJsonObject()) {
87+
JsonObject obj = element.getAsJsonObject();
88+
JsonElement query = obj.get("query");
89+
if (query != null) {
90+
String key = query.getAsString();
91+
JsonArray texts = obj.get("texts").getAsJsonArray();
92+
int size = texts.size();
93+
List<StringPair> inputs = new ArrayList<>(size);
94+
for (int i = 0; i < size; ++i) {
95+
String value = texts.get(i).getAsString();
96+
inputs.add(new StringPair(key, value));
97+
}
98+
ctx.setAttachment("batch", Boolean.TRUE);
99+
return translator.batchProcessInput(ctx, inputs);
100+
} else {
101+
pair = parseStringPair(obj);
102+
}
103+
} else {
104+
throw new TranslateException("Unexpected json type");
80105
}
81106
} catch (JsonParseException e) {
82107
throw new TranslateException("Input is not a valid json.", e);
83108
}
84109
} else {
110+
String text = input.getAsString("text");
111+
String textPair = input.getAsString("text_pair");
112+
if (text != null && textPair != null) {
113+
pair = new StringPair(text, textPair);
114+
}
85115
String key = input.getAsString("key");
86116
String value = input.getAsString("value");
87-
if (key == null || value == null) {
88-
throw new TranslateException("Missing key or value in input.");
117+
if (key != null && value != null) {
118+
pair = new StringPair(key, value);
89119
}
90-
pair = new StringPair(key, value);
120+
}
121+
122+
if (pair == null) {
123+
throw new TranslateException("Missing key or value in input.");
91124
}
92125

93126
NDList ret = translator.processInput(ctx, pair);
@@ -115,4 +148,18 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
115148
}
116149
return output;
117150
}
151+
152+
private StringPair parseStringPair(JsonObject json) throws TranslateException {
153+
JsonElement text = json.get("text");
154+
JsonElement textPair = json.get("text_pair");
155+
if (text != null && textPair != null) {
156+
return new StringPair(text.getAsString(), textPair.getAsString());
157+
}
158+
JsonElement key = json.get("key");
159+
JsonElement value = json.get("value");
160+
if (key != null && value != null) {
161+
return new StringPair(key.getAsString(), value.getAsString());
162+
}
163+
throw new TranslateException("Missing text or text_pair in json.");
164+
}
118165
}

api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java

+6
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
6565
}
6666

6767
String contentType = input.getProperty("Content-Type", null);
68+
if (contentType != null) {
69+
int pos = contentType.indexOf(';');
70+
if (pos > 0) {
71+
contentType = contentType.substring(0, pos);
72+
}
73+
}
6874
QAInput qa;
6975
if ("application/json".equals(contentType)) {
7076
String json = input.getData().getAsString();

extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java

+44
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@
3838
import java.nio.file.Path;
3939
import java.nio.file.Paths;
4040
import java.util.HashMap;
41+
import java.util.List;
4142
import java.util.Map;
4243

4344
public class CrossEncoderTranslatorTest {
4445

4546
@Test
47+
@SuppressWarnings("unchecked")
4648
public void testCrossEncoderTranslator()
4749
throws ModelException, IOException, TranslateException {
4850
String text1 = "Sentence 1";
@@ -119,6 +121,48 @@ public void testCrossEncoderTranslator()
119121
float[] buf = (float[]) res.getData().getAsObject();
120122
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
121123

124+
input = new Input();
125+
input.add("text", text1);
126+
input.add("text_pair", text2);
127+
res = predictor.predict(input);
128+
buf = (float[]) res.getData().getAsObject();
129+
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
130+
131+
input = new Input();
132+
input.addProperty("Content-Type", "application/json; charset=utf-8");
133+
input.add("data", "{\"text\": \"" + text1 + "\", \"text_pair\": \"" + text2 + "\"}");
134+
res = predictor.predict(input);
135+
buf = (float[]) res.getData().getAsObject();
136+
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
137+
138+
input = new Input();
139+
input.addProperty("Content-Type", "application/json; charset=utf-8");
140+
input.add("data", "{\"key\": \"" + text1 + "\", \"value\": \"" + text2 + "\"}");
141+
res = predictor.predict(input);
142+
buf = (float[]) res.getData().getAsObject();
143+
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
144+
145+
input = new Input();
146+
input.addProperty("Content-Type", "application/json");
147+
input.add("data", "{\"query\": \"" + text1 + "\", \"texts\": [\"" + text2 + "\"]}");
148+
res = predictor.predict(input);
149+
buf = ((List<float[]>) res.getData().getAsObject()).get(0);
150+
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
151+
152+
input = new Input();
153+
input.addProperty("Content-Type", "application/json");
154+
input.add("data", "{\"query\": \"" + text1 + "\", \"texts\": [\"" + text2 + "\"]}");
155+
res = predictor.predict(input);
156+
buf = ((List<float[]>) res.getData().getAsObject()).get(0);
157+
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
158+
159+
input = new Input();
160+
input.addProperty("Content-Type", "application/json");
161+
input.add("data", "[{\"text\": \"" + text1 + "\", \"text_pair\": \"" + text2 + "\"}]");
162+
res = predictor.predict(input);
163+
buf = ((List<float[]>) res.getData().getAsObject()).get(0);
164+
Assert.assertEquals(buf[0], 0.32455865, 0.0001);
165+
122166
Assert.assertThrows(TranslateException.class, () -> predictor.predict(new Input()));
123167

124168
Assert.assertThrows(

extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/QuestionAnsweringTranslatorTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public void testQATranslator() throws ModelException, IOException, TranslateExce
123123
TranslateException.class,
124124
() -> {
125125
Input req = new Input();
126-
req.addProperty("Content-Type", "application/json");
126+
req.addProperty("Content-Type", "application/json; charset=utf-8");
127127
req.add("Invalid json");
128128
predictor.predict(req);
129129
});

extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ public void testTextEmbeddingTranslator()
213213
Map<String, String> map = new HashMap<>();
214214
map.put("inputs", text);
215215
input.add(JsonUtils.GSON.toJson(map));
216-
input.addProperty("Content-Type", "application/json");
216+
input.addProperty("Content-Type", "application/json; charset=utf-8");
217217
out = predictor.predict(input);
218218
res = (float[]) out.getData().getAsObject();
219219
Assert.assertEquals(res.length, 384);

0 commit comments

Comments
 (0)