Skip to content

Commit 7ac48dd

Browse files
authored
[example] Enable PyTorch for some training example (#3398)
1 parent 11ff0c1 commit 7ac48dd

File tree

5 files changed

+19
-21
lines changed

5 files changed

+19
-21
lines changed

examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,22 @@ public class TrainCaptchaTest {
2727
public void testTrainCaptcha() throws IOException, TranslateException {
2828
TestRequirements.linux();
2929

30-
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
30+
// TODO: PyTorch
31+
/*
32+
ai.djl.engine.EngineException: index 11 is out of bounds for dimension 1 with size 11
33+
at app//ai.djl.pytorch.jni.PyTorchLibrary.torchGather(Native Method)
34+
at app//ai.djl.pytorch.jni.JniUtils.pick(JniUtils.java:581)
35+
at app//ai.djl.pytorch.jni.JniUtils.indexAdv(JniUtils.java:417)
36+
at app//ai.djl.pytorch.engine.PtNDArrayIndexer.get(PtNDArrayIndexer.java:74)
37+
at app//ai.djl.ndarray.NDArray.get(NDArray.java:614)
38+
at app//ai.djl.ndarray.NDArray.get(NDArray.java:603)
39+
at app//ai.djl.training.loss.SoftmaxCrossEntropyLoss.evaluate(SoftmaxCrossEntropyLoss.java:86)
40+
at app//ai.djl.training.loss.IndexLoss.evaluate(IndexLoss.java:55)
41+
at app//ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:68)
42+
at app//ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:124)
43+
at app//ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
44+
at app//ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
45+
*/
3146
String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
3247
TrainingResult result = TrainCaptcha.runExample(args);
3348
Assert.assertNotNull(result);

examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java

+1-9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
*/
1313
package ai.djl.examples.training;
1414

15-
import ai.djl.engine.Engine;
1615
import ai.djl.training.TrainingResult;
1716
import ai.djl.translate.TranslateException;
1817

@@ -25,14 +24,7 @@ public class TrainMnistWithLSTMTest {
2524

2625
@Test
2726
public void testTrainMnistWithLSTM() throws IOException, TranslateException {
28-
String[] args;
29-
Engine engine = Engine.getEngine("PyTorch");
30-
if (engine.getGpuCount() > 0) {
31-
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
32-
args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
33-
} else {
34-
args = new String[] {"-g", "1", "-e", "1", "-m", "2"};
35-
}
27+
String[] args = {"-g", "1", "-e", "1", "-m", "2"};
3628
TrainingResult result = TrainMnistWithLSTM.runExample(args);
3729
Assert.assertNotNull(result);
3830
}

examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java

-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ public void testTrainResNet() throws ModelException, IOException, TranslateExcep
3636

3737
// Limit max 4 gpu for cifar10 training to make it converge faster.
3838
// and only train 10 batch for unit test.
39-
// only MXNet support symbolic model
4039
String[] args = {"-e", "2", "-g", "4", "-m", "10", "-p"};
4140
TrainingResult result = TrainResnetWithCifar10.runExample(args);
4241

examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ public void testTrainSentimentAnalysis()
2929
TestRequirements.nightly();
3030
TestRequirements.gpu("MXNet", 1);
3131

32+
// TODO: Add a PyTorch Glove model to model zoo
3233
String[] args = {"-e", "1", "-g", "1", "--engine", "MXNet"};
3334
TrainSentimentAnalysis.runExample(args);
3435
}

examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java

+1-10
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
package ai.djl.examples.training;
1515

16-
import ai.djl.engine.Engine;
1716
import ai.djl.training.TrainingResult;
1817
import ai.djl.translate.TranslateException;
1918

@@ -26,15 +25,7 @@ public class TrainTimeSeriesTest {
2625

2726
@Test
2827
public void testTrainTimeSeries() throws TranslateException, IOException {
29-
String[] args;
30-
Engine engine = Engine.getEngine("PyTorch");
31-
if (engine.getGpuCount() > 0) {
32-
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
33-
args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"};
34-
} else {
35-
args = new String[] {"-g", "1", "-e", "5", "-b", "32"};
36-
}
37-
28+
String[] args = {"-g", "1", "-e", "5", "-b", "32"};
3829
TrainingResult result = TrainTimeSeries.runExample(args);
3930
Assert.assertNotNull(result);
4031
float loss = result.getTrainLoss();

0 commit comments

Comments
 (0)