
Commit 80f3f6c

add_Jittor: Passing model tests, Parameter and Module Container tests

add_Jittor: Passing model tests, Parameter and Module Container tests.

Additional functionality:
1. TrainOneStep integration.
2. Updated core/train_jt so that accuracy can be measured.
3. Updated the Jittor optimizer: replaced the gradient and apply_gradient functions with Jittor's default zero_grad() and step() functions, and added a new Set() function to register the trainable_weights parameters with the optimizer.
4. Updated the Jittor metrics for Accuracy, Recall, Precision and AUC.
5. Created the Jittor model tutorial file jittor_module_tutorial.py.
6. Module Container and Parameter Container: updated the core_jittor ModuleList and ParameterDict to enable OrderedDict initialization. This was previously unavailable because the parent class (the Jittor Module) initializes a plain dict by default, which caused integration issues; the fix updates these functions and excludes the parent Module from them.

Areas to optimize: the Jittor integration is currently limited in the complexity of NN layers it can handle, so enabling it to run large-model training remains open.
1 parent 772af7a commit 80f3f6c

26 files changed (+853 -852 lines)
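
The diffs below drop the hand-written training loop in favor of the built-in Model wrapper and an optimizer constructed without weights. A minimal sketch of that flow, assuming the wrapper registers the network's trainable_weights with the optimizer (via the new Set() named above) and drives zero_grad()/step() internally; TinyNet and train_loader are illustrative names, everything else is taken from the diffs below:

    import os
    os.environ['TL_BACKEND'] = 'jittor'  # must be set before importing tensorlayerx

    import tensorlayerx as tlx
    from tensorlayerx.nn import Linear, Module
    from tensorlayerx.optimizers import Adam

    class TinyNet(Module):
        def __init__(self):
            super(TinyNet, self).__init__()
            self.fc = Linear(out_features=10, act=None, in_features=32)

        def forward(self, x):
            return self.fc(x)

    net = TinyNet()
    optimizer = Adam(lr=0.001)  # weights are no longer passed at construction
    loss_fn = tlx.losses.softmax_cross_entropy_with_logits
    metric = tlx.metrics.Accuracy()
    # The Model wrapper is assumed to hand net.trainable_weights to the
    # optimizer (via the new Set()) and to call zero_grad()/step() per batch.
    tlx_model = tlx.model.Model(network=net, loss_fn=loss_fn, optimizer=optimizer, metrics=metric)
    # tlx_model.train(n_epoch=1, train_dataset=train_loader)  # train_loader: a DataLoader of (x, y) batches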

examples/basic_tutorials/cifar10_cnn.py (+31 -36)

@@ -1,15 +1,16 @@
 #! /usr/bin/python
 # -*- coding: utf-8 -*-
 
+################################ TensorLayerX and Jittor. #################################
+
 import os
 import time
-import numpy as np
 import tensorlayerx as tlx
 from tensorlayerx.dataflow import Dataset, DataLoader
 from tensorlayerx.vision.transforms import (
     Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop
 )
-from tensorlayerx.nn import Conv2d, Linear, Flatten, Module
+from tensorlayerx.nn import Conv2d, Linear, Flatten, Module, MaxPool2d, BatchNorm2d
 from tensorlayerx.optimizers import Adam
 from tqdm import tqdm
 
@@ -18,9 +19,7 @@
 
 os.environ['TL_BACKEND'] = 'jittor'
 
-
-
-# Download and prepare the CIFAR10 dataset with progress bar
+# Download and prepare the CIFAR10 dataset
 print("Downloading CIFAR10 dataset...")
 X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
 
@@ -59,58 +58,54 @@ def __len__(self):
 train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
 test_dataloader = DataLoader(test_dataset, batch_size=128)
 
-# Define a simple CNN model
+
 class SimpleCNN(Module):
     def __init__(self):
         super(SimpleCNN, self).__init__()
         self.conv1 = Conv2d(16, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=3)
+        self.conv2 = Conv2d(32, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=16)
+        self.maxpool1 = MaxPool2d((2, 2), (2, 2), padding='SAME')
+        self.conv3 = Conv2d(64, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=32)
+        self.bn1 = BatchNorm2d(num_features=64, act=tlx.nn.ReLU)
+        self.conv4 = Conv2d(128, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=64)
+        self.maxpool2 = MaxPool2d((2, 2), (2, 2), padding='SAME')
         self.flatten = Flatten()
-        self.fc1 = Linear(out_features=64, act=tlx.nn.ReLU, in_features=16 * 24 * 24)
-        self.fc2 = Linear(out_features=10, act=None, in_features=64)
+        self.fc1 = Linear(out_features=128, act=tlx.nn.ReLU, in_features=128 * 6 * 6)
+        self.fc2 = Linear(out_features=64, act=tlx.nn.ReLU, in_features=128)
+        self.fc3 = Linear(out_features=10, act=None, in_features=64)
 
     def forward(self, x):
         z = self.conv1(x)
+        z = self.conv2(z)
+        z = self.maxpool1(z)
+        z = self.conv3(z)
+        z = self.bn1(z)
+        z = self.conv4(z)
+        z = self.maxpool2(z)
         z = self.flatten(z)
         z = self.fc1(z)
         z = self.fc2(z)
+        z = self.fc3(z)
         return z
 
+
 # Instantiate the model
 model = SimpleCNN()
 
 # Define the optimizer
-optimizer = Adam(model.trainable_weights, lr=0.001)
+optimizer = Adam(lr=0.001)
+# optimizer = Adam(lr=0.001, params=model.trainable_weights)
 
 # Define the loss function
 loss_fn = tlx.losses.softmax_cross_entropy_with_logits
 
-# Training loop
-n_epoch = 2
-for epoch in range(n_epoch):
-    start_time = time.time()
-    model.set_train()
-    train_loss, n_iter = 0, 0
-
-    with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{n_epoch}", unit="batch") as pbar:
-        for X_batch, y_batch in train_dataloader:
-            X_batch = tlx.convert_to_tensor(X_batch)
-            y_batch = tlx.convert_to_tensor(y_batch)
-            _logits = model(X_batch)
-            loss = loss_fn(_logits, y_batch)
-
-            optimizer.zero_grad()
-            optimizer.step(loss)
-
-            train_loss += loss.item()
-            n_iter += 1
-            pbar.update(1)
-
-    print(f"Epoch {epoch + 1} of {n_epoch} took {time.time() - start_time:.2f}s")
-    print(f" train loss: {train_loss / n_iter:.4f}")
-
-
-
-################################ TensorLayerX and Jittor can be mixed programming. #################################
+# Use the built-in training method
+metric = tlx.metrics.Recall()
+tlx_model = tlx.model.Model(network=model, loss_fn=loss_fn, optimizer=optimizer, metrics=metric)
+tlx_model.train(n_epoch=2, train_dataset=train_dataloader, print_freq=1, print_train_batch=True)
+
+
+################################ TensorLayerX and Torch. #################################
 
 
 
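None of the diffs on this page show item 6 of the commit message (OrderedDict initialization of the container classes), so here is a hedged sketch of the intended usage; the ParameterDict constructor call and tlx.nn.Parameter are assumptions based on the commit message, not code from this commit:

    from collections import OrderedDict
    import os
    os.environ['TL_BACKEND'] = 'jittor'

    import tensorlayerx as tlx

    # Assumed usage: building the container directly from an OrderedDict,
    # which previously failed because the parent Jittor Module pre-initializes
    # a plain dict by default.
    params = tlx.nn.ParameterDict(OrderedDict([
        ('weight', tlx.nn.Parameter(tlx.ones((3, 3)))),
        ('bias', tlx.nn.Parameter(tlx.zeros((3,)))),
    ]))
    # Iteration is expected to follow the OrderedDict insertion order,
    # which is the point of the fix.
    for name in params:
        print(name)
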
examples/basic_tutorials/cifar10_cnn_dist.py (+2 -1)

@@ -2,7 +2,8 @@
 # -*- coding: utf-8 -*-
 
 import os
-os.environ['TL_BACKEND'] = 'paddle'
+# os.environ['TL_BACKEND'] = 'paddle'
+os.environ['TL_BACKEND'] = 'jittor'
 # os.environ['TL_BACKEND'] = 'tensorflow'
 # os.environ['TL_BACKEND'] = 'mindspore'
 # os.environ['TL_BACKEND'] = 'torch'

examples/basic_tutorials/cifar10_cnn_train.py (+1 -2)

@@ -9,7 +9,6 @@
 os.environ['TL_BACKEND'] = 'jittor'
 # os.environ['TL_BACKEND'] = 'tensorflow'
 # os.environ['TL_BACKEND'] = 'mindspore'
-
 # os.environ['TL_BACKEND'] = 'torch'
 
 
@@ -76,7 +75,7 @@ def forward(self, x):
 
 # Define the loss function, optimizer, etc.
 loss_fn=tlx.losses.softmax_cross_entropy_with_logits
-optimizer = tlx.optimizers.Adam(net.trainable_weights, lr=learning_rate)
+optimizer = tlx.optimizers.Adam(lr=learning_rate)
 metrics = tlx.metrics.Accuracy()
 
 
examples/basic_tutorials/gradient_clip_mixed_tensorflow.py (+2 -1)

@@ -2,9 +2,10 @@
 # -*- coding: utf-8 -*-
 # The tensorlayerx and tensorflow operators can be mixed
 import os
-os.environ['TL_BACKEND'] = 'tensorflow'
+# os.environ['TL_BACKEND'] = 'tensorflow'
 # os.environ['TL_BACKEND'] = 'paddle'
 # os.environ['TL_BACKEND'] = 'torch'
+os.environ['TL_BACKEND'] = 'jittor'
 
 
 import time
