#!/usr/bin/python
# -*- coding: utf-8 -*-

+################################ TensorLayerX and Jittor. #################################
+
import os
import time

# TL_BACKEND must be set before tensorlayerx is imported, or the backend choice has no effect.
os.environ['TL_BACKEND'] = 'jittor'

-import numpy as np
import tensorlayerx as tlx
from tensorlayerx.dataflow import Dataset, DataLoader
from tensorlayerx.vision.transforms import (
    Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop
)
-from tensorlayerx.nn import Conv2d, Linear, Flatten, Module
+from tensorlayerx.nn import Conv2d, Linear, Flatten, Module, MaxPool2d, BatchNorm2d
from tensorlayerx.optimizers import Adam
from tqdm import tqdm

-# Download and prepare the CIFAR10 dataset with progress bar
+# Download and prepare the CIFAR10 dataset
print("Downloading CIFAR10 dataset...")
X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
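The transform pipelines live inside the hunk elided below; a minimal sketch of a composition consistent with the imports above and with the 24x24 crop implied by the old fc1 layer (the exact ordering and factor values are assumptions, not the commit's code):

train_transforms = Compose([
    RandomCrop(size=[24, 24]),
    RandomFlipHorizontal(),
    RandomBrightness(brightness_factor=(0.5, 1.5)),
    RandomContrast(contrast_factor=(0.5, 1.5)),
    StandardizePerImage()
])
test_transforms = Compose([Resize(size=(24, 24)), StandardizePerImage()])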
@@ -59,58 +58,54 @@ def __len__(self):
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128)
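The Dataset subclass feeding these loaders is elided by the hunk above (only its __len__ line is visible in the hunk header); a minimal sketch of the usual TensorLayerX pattern, with field names and dtype handling assumed rather than taken from the commit:

class CIFAR10Dataset(Dataset):
    # Wraps the numpy arrays and applies the vision transforms per sample.
    def __init__(self, data, label, transforms):
        self.data = data
        self.label = label
        self.transforms = transforms

    def __getitem__(self, idx):
        x = self.data[idx].astype('uint8')    # transforms expect HWC uint8 images
        y = self.label[idx].astype('int64')
        return self.transforms(x), y

    def __len__(self):
        return len(self.label)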
-# Define a simple CNN model
+
class SimpleCNN(Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = Conv2d(16, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=3)
+        self.conv2 = Conv2d(32, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=16)
+        self.maxpool1 = MaxPool2d((2, 2), (2, 2), padding='SAME')
+        self.conv3 = Conv2d(64, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=32)
+        self.bn1 = BatchNorm2d(num_features=64, act=tlx.nn.ReLU)
+        self.conv4 = Conv2d(128, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=64)
+        self.maxpool2 = MaxPool2d((2, 2), (2, 2), padding='SAME')
        self.flatten = Flatten()
-        self.fc1 = Linear(out_features=64, act=tlx.nn.ReLU, in_features=16 * 24 * 24)
-        self.fc2 = Linear(out_features=10, act=None, in_features=64)
+        self.fc1 = Linear(out_features=128, act=tlx.nn.ReLU, in_features=128 * 6 * 6)
+        self.fc2 = Linear(out_features=64, act=tlx.nn.ReLU, in_features=128)
+        self.fc3 = Linear(out_features=10, act=None, in_features=64)

    def forward(self, x):
        z = self.conv1(x)
+        z = self.conv2(z)
+        z = self.maxpool1(z)
+        z = self.conv3(z)
+        z = self.bn1(z)
+        z = self.conv4(z)
+        z = self.maxpool2(z)
        z = self.flatten(z)
        z = self.fc1(z)
        z = self.fc2(z)
+        z = self.fc3(z)
        return z
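The new flatten size follows from the input pipeline: the old fc1 used in_features=16 * 24 * 24, so samples reach the network as 24x24 crops; the SAME-padded, stride-1 convs preserve spatial size and each 2x2 pool halves it (24 -> 12 -> 6), giving 128 * 6 * 6 flattened features. A quick shape-only sanity check of that arithmetic, assuming the default NHWC layout (not part of the commit):

import numpy as np
_dummy = tlx.convert_to_tensor(np.zeros((1, 24, 24, 3), dtype='float32'))
print(SimpleCNN()(_dummy).shape)   # expected: (1, 10)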

+
# Instantiate the model
model = SimpleCNN()

# Define the optimizer
-optimizer = Adam(model.trainable_weights, lr=0.001)
+optimizer = Adam(lr=0.001)
+# optimizer = Adam(lr=0.001, params=model.trainable_weights)

# Define the loss function
loss_fn = tlx.losses.softmax_cross_entropy_with_logits

-# Training loop
-n_epoch = 2
-for epoch in range(n_epoch):
-    start_time = time.time()
-    model.set_train()
-    train_loss, n_iter = 0, 0
-
-    with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{n_epoch}", unit="batch") as pbar:
-        for X_batch, y_batch in train_dataloader:
-            X_batch = tlx.convert_to_tensor(X_batch)
-            y_batch = tlx.convert_to_tensor(y_batch)
-            _logits = model(X_batch)
-            loss = loss_fn(_logits, y_batch)
-
-            optimizer.zero_grad()
-            optimizer.step(loss)
-
-            train_loss += loss.item()
-            n_iter += 1
-            pbar.update(1)
-
-    print(f"Epoch {epoch + 1} of {n_epoch} took {time.time() - start_time:.2f}s")
-    print(f"   train loss: {train_loss / n_iter:.4f}")
-
-
-################################ TensorLayerX and Jittor can be mixed programming. #################################
+# Use the built-in training method
+metric = tlx.metrics.Recall()
+tlx_model = tlx.model.Model(network=model, loss_fn=loss_fn, optimizer=optimizer, metrics=metric)
+tlx_model.train(n_epoch=2, train_dataset=train_dataloader, print_freq=1, print_train_batch=True)
+
+
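A hedged sketch of evaluating the trained network on the test split after Model.train returns; set_eval and tlx.metrics.Accuracy are standard TensorLayerX API, but this snippet is an illustration, not part of the commit:

model.set_eval()
acc = tlx.metrics.Accuracy()
for X_batch, y_batch in test_dataloader:
    logits = model(tlx.convert_to_tensor(X_batch))
    acc.update(logits, tlx.convert_to_tensor(y_batch))
print(f"test accuracy: {acc.result():.4f}")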
+################################ TensorLayerX and Torch. #################################

