#!/usr/bin/python
# -*- coding: utf-8 -*-

import os
# os.environ['TL_BACKEND'] = 'paddle'
# os.environ['TL_BACKEND'] = 'tensorflow'
# os.environ['TL_BACKEND'] = 'mindspore'
os.environ['TL_BACKEND'] = 'torch'

import time
from tensorlayerx.dataflow import Dataset, DataLoader
from tensorlayerx.vision.transforms import (
    Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop
)
from tensorlayerx.model import TrainOneStep
from tensorlayerx.nn import Module
import tensorlayerx as tlx
from tensorlayerx.nn import (Conv2d, Linear, Flatten, MaxPool2d, BatchNorm2d)
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1,
                    help="For distributed training: local_rank")
args = parser.parse_args()
# enable debug logging
tlx.logging.set_verbosity(tlx.logging.DEBUG)

tlx.ops.set_device(device='MLU', id=args.local_rank)
tlx.ops.distributed_init(backend="cncl")
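# Launch note (assumption, not part of the original script): with the torch backend and
# the "--local_rank" argument above, this script is normally started through a distributed
# launcher, e.g. something like
#   python -m torch.distributed.launch --nproc_per_node=<num_devices> this_script.py
# The exact launcher and flags depend on your MLU/torch setup.
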
# ################## Download and prepare the CIFAR10 dataset ##################
# Download the CIFAR10 dataset (if it is not already cached) and load it into
# numpy arrays, one image per row with shape [32, 32, 3].
X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
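# CIFAR10 ships 50,000 training and 10,000 test images, so X_train has shape
# (50000, 32, 32, 3) and y_train has shape (50000,); X_test/y_test hold the 10,000-image test split.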

# ################## CIFAR10 dataset ##################
# We define a Dataset class for loading CIFAR10 images and labels.
class make_dataset(Dataset):

    def __init__(self, data, label, transforms):
        self.data = data
        self.label = label
        self.transforms = transforms

    def __getitem__(self, idx):
        x = self.data[idx].astype('uint8')
        y = self.label[idx].astype('int64')
        x = self.transforms(x)

        return x, y

    def __len__(self):
        return len(self.label)

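# The DataLoader used below calls __len__ to size this dataset and __getitem__ to
# fetch and transform one (image, label) pair at a time.
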
# We define the CIFAR10 images preprocessing pipeline.
train_transforms = Compose(  # Combine multiple operations sequentially
    [
        RandomCrop(size=[24, 24]),  # randomly crop images to shape [24, 24]
        RandomFlipHorizontal(),  # randomly flip each image horizontally
        RandomBrightness(brightness_factor=(0.5, 1.5)),  # adjust brightness randomly within the range (0.5, 1.5)
        RandomContrast(contrast_factor=(0.5, 1.5)),  # adjust contrast randomly within the range (0.5, 1.5)
        StandardizePerImage()  # standardize each image to zero mean and unit variance
    ]
)

test_transforms = Compose([Resize(size=(24, 24)), StandardizePerImage()])
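# The test pipeline only resizes to the 24x24 training crop size and standardizes, so
# evaluation sees the same input shape and scaling as training but no random augmentation.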

# We use DataLoader to batch and shuffle the data and turn it into an iterator.
train_dataset = make_dataset(data=X_train, label=y_train, transforms=train_transforms)
test_dataset = make_dataset(data=X_test, label=y_test, transforms=test_transforms)

train_dataset = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataset = DataLoader(test_dataset, batch_size=128)

# ################## CNN network ##################
class CNN(Module):

    def __init__(self):
        super(CNN, self).__init__()
        # Parameter initialization methods
        W_init = tlx.nn.initializers.truncated_normal(stddev=5e-2)
        W_init2 = tlx.nn.initializers.truncated_normal(stddev=0.04)
        b_init2 = tlx.nn.initializers.constant(value=0.1)

        # 2D convolution: 64 output channels, kernel size (5, 5), stride (1, 1), 'SAME' padding, 3 input channels
        self.conv1 = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv1', in_channels=3)
        # 2D batch normalization with ReLU activation
        self.bn = BatchNorm2d(num_features=64, act=tlx.nn.ReLU)
        # 2D max pooling layer
        self.maxpool1 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')

        self.conv2 = Conv2d(
            64, (5, 5), (1, 1), padding='SAME', act=tlx.nn.ReLU, W_init=W_init, name='conv2', in_channels=64
        )
        self.maxpool2 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')
        # Flatten the feature maps to a 1D vector
        self.flatten = Flatten(name='flatten')
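        # With 24x24 inputs, the two stride-2 'SAME' poolings reduce the spatial size to
        # 12x12 and then 6x6, so the flattened feature vector has 6 * 6 * 64 = 2304 elements.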
        # Linear layer with 384 units, using ReLU for output
        self.linear1 = Linear(384, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear1relu', in_features=2304)
        self.linear2 = Linear(192, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear2relu', in_features=384)
        self.linear3 = Linear(10, act=None, W_init=W_init2, name='output', in_features=192)

    # We define the forward computation process.
    def forward(self, x):
        z = self.conv1(x)
        z = self.bn(z)
        z = self.maxpool1(z)
        z = self.conv2(z)
        z = self.maxpool2(z)
        z = self.flatten(z)
        z = self.linear1(z)
        z = self.linear2(z)
        z = self.linear3(z)
        return z


# get the network
net = CNN()

# training settings
n_epoch = 500
learning_rate = 0.0001
print_freq = 5
n_step_epoch = int(len(y_train) / 128)
n_step = n_epoch * n_step_epoch
shuffle_buffer_size = 128
# Get training parameters
train_weights = net.trainable_weights
# Define the optimizer, use the Adam optimizer.
optimizer = tlx.optimizers.Adam(learning_rate)
# Define evaluation metrics.
metrics = tlx.metrics.Accuracy()

# Define the loss calculation process
class WithLoss(Module):

    def __init__(self, net, loss_fn):
        super(WithLoss, self).__init__()
        self._net = net
        self._loss_fn = loss_fn

    def forward(self, data, label):
        out = self._net(data)
        loss = self._loss_fn(out, label)
        return loss


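# WithLoss bundles the forward pass and the loss computation into a single Module,
# which is what TrainOneStep below expects so it can compute gradients from one call.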
net_with_loss = WithLoss(net.mlu(), loss_fn=tlx.losses.softmax_cross_entropy_with_logits).mlu()
model = tlx.ops.distributed_model(net_with_loss, device_ids=[args.local_rank],
                                  output_device=args.local_rank,
                                  find_unused_parameters=True)
# Initialize one-step training
# net_with_train = TrainOneStep(net_with_loss, optimizer, train_weights)
net_with_train = TrainOneStep(model, optimizer, train_weights)
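# Each call to net_with_train(data, label) runs the forward pass, computes the loss,
# backpropagates, applies one optimizer update, and returns the loss value.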

# Custom training loop
for epoch in range(n_epoch):
    start_time = time.time()
    # Set the network to training mode
    net.set_train()
    train_loss, train_acc, n_iter = 0, 0, 0
    # Get training data and labels
    for X_batch, y_batch in train_dataset:
        # Compute the loss value; the gradient update is applied automatically
        _loss_ce = net_with_train(X_batch.mlu(), y_batch.mlu())
        train_loss += _loss_ce

        n_iter += 1
        _logits = net(X_batch.mlu())
        # Calculate accuracy
        metrics.update(_logits, y_batch.mlu())
        train_acc += metrics.result()
        metrics.reset()
        if n_iter % 100 == 0:
            print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
            print("rank {} train loss: {}".format(args.local_rank, train_loss / n_iter))
            print("rank {} train acc: {}".format(args.local_rank, train_acc / n_iter))

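# The test_dataset built above is never consumed in the original loop. A minimal
# evaluation pass is sketched below (an addition, mirroring the training metrics logic);
# it reuses only objects defined above.
net.set_eval()  # switch the network to evaluation mode
test_acc, n_test_iter = 0, 0
for X_batch, y_batch in test_dataset:
    _logits = net(X_batch.mlu())
    metrics.update(_logits, y_batch.mlu())
    test_acc += metrics.result()
    metrics.reset()
    n_test_iter += 1
print("rank {} test acc: {}".format(args.local_rank, test_acc / n_test_iter))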