Skip to content

feat: support dataloader with multiple inputs #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add :class:`SoftGradientBoostingClassifier` and :class:`SoftGradientBoostingRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Feature| |API| Support using dataloader with multiple input | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Fix missing functionality of ``use_reduction_sum`` for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
* |Enhancement| Relax :mod:`tensorboard` as a soft dependency | `@xuyxu <https://github.com/xuyxu>`__
* |Enhancement| |API| Simplify the training workflow of :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
Expand Down
4 changes: 0 additions & 4 deletions torchensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from .adversarial_training import AdversarialTrainingRegressor
from .fast_geometric import FastGeometricClassifier
from .fast_geometric import FastGeometricRegressor
from .soft_gradient_boosting import SoftGradientBoostingClassifier
from .soft_gradient_boosting import SoftGradientBoostingRegressor


__all__ = [
Expand All @@ -31,6 +29,4 @@
"AdversarialTrainingRegressor",
"FastGeometricClassifier",
"FastGeometricRegressor",
"SoftGradientBoostingClassifier",
"SoftGradientBoostingRegressor",
]
58 changes: 30 additions & 28 deletions torchensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn as nn

from . import _constants as const
from .utils.io import split_data_target
from .utils.logging import get_tb_logger


Expand Down Expand Up @@ -170,27 +171,26 @@ def fit(
"""

@torch.no_grad()
def predict(self, X, return_numpy=True):
"""Docstrings decorated by downstream models."""
def predict(self, *x):
"""Docstrings decorated by downstream ensembles."""
self.eval()
pred = None

if isinstance(X, torch.Tensor):
pred = self.forward(X.to(self.device))
elif isinstance(X, np.ndarray):
X = torch.Tensor(X).to(self.device)
pred = self.forward(X)
else:
msg = (
"The type of input X should be one of {{torch.Tensor,"
" np.ndarray}}."
)
raise ValueError(msg)
# Copy data
x_device = []
for data in x:
if isinstance(data, torch.Tensor):
x_device.append(data.to(self.device))
elif isinstance(data, np.ndarray):
x_device.append(torch.Tensor(data).to(self.device))
else:
msg = (
"The type of input X should be one of {{torch.Tensor,"
" np.ndarray}}."
)
raise ValueError(msg)

pred = self.forward(*x_device)
pred = pred.cpu()
if return_numpy:
return pred.numpy()

return pred


Expand All @@ -212,7 +212,8 @@ def _decide_n_outputs(self, train_loader):
# Infer `n_outputs` from the dataloader
else:
labels = []
for _, (_, target) in enumerate(train_loader):
for _, elem in enumerate(train_loader):
_, target = split_data_target(elem, self.device)
labels.append(target)
labels = torch.unique(torch.cat(labels))
n_outputs = labels.size(0)
Expand All @@ -228,9 +229,9 @@ def evaluate(self, test_loader, return_loss=False):
criterion = nn.CrossEntropyLoss()
loss = 0.0

for _, (data, target) in enumerate(test_loader):
data, target = data.to(self.device), target.to(self.device)
output = self.forward(data)
for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)
output = self.forward(*data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
Expand Down Expand Up @@ -258,25 +259,26 @@ def _decide_n_outputs(self, train_loader):
The number of outputs equals the number of target variables for
regressors (e.g., `1` in univariate regression).
"""
for _, (_, target) in enumerate(train_loader):
for _, elem in enumerate(train_loader):
_, target = split_data_target(elem, self.device)
if len(target.size()) == 1:
n_outputs = 1
n_outputs = 1 # univariate regression
else:
n_outputs = target.size(1)
n_outputs = target.size(1) # multivariate regression
break

return n_outputs

@torch.no_grad()
def evaluate(self, test_loader):
"""Docstrings decorated by downstream models."""
"""Docstrings decorated by downstream ensembles."""
self.eval()
mse = 0.0
criterion = nn.MSELoss()

for _, (data, target) in enumerate(test_loader):
data, target = data.to(self.device), target.to(self.device)
output = self.forward(data)
for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)
output = self.forward(*data)
mse += criterion(output, target)

return float(mse) / len(test_loader)
7 changes: 1 addition & 6 deletions torchensemble/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,13 @@
Parameters
----------
X : {Tensor, ndarray}
A data batch in the form of tensor or Numpy array.
return_numpy : bool, default=True
Whether to convert the predictions into a Numpy array.
A data batch in the form of tensor or numpy array.

Returns
-------
pred : Array of shape (n_samples, n_outputs)
For classifiers, ``n_outputs`` is the number of distinct classes. For
regressors, ``n_output`` is the number of target variables.

- If ``return_numpy`` is ``False``, the result is a tensor.
- If ``return_numpy`` is ``True``, the result is a Numpy array.
"""


Expand Down
91 changes: 49 additions & 42 deletions torchensemble/adversarial_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,24 @@ def _parallel_fit_per_epoch(
# Parallelization corrupts the binding between optimizer and scheduler
set_module.update_lr(optimizer, cur_lr)

for batch_idx, (data, target) in enumerate(train_loader):
for batch_idx, elem in enumerate(train_loader):

batch_size = data.size()[0]
data, target = data.to(device), target.to(device)
data.requires_grad = True
data, target = io.split_data_target(elem, device)
batch_size = data[0].size(0)
for tensor in data:
tensor.requires_grad = True

# Get adversarial samples
_output = estimator(data)
_output = estimator(*data)
_loss = criterion(_output, target)
_loss.backward()
data_grad = data.grad.data
data_grad = [tensor.grad.data for tensor in data]
adv_data = _get_fgsm_samples(data, epsilon, data_grad)

# Compute the training loss
optimizer.zero_grad()
org_output = estimator(data)
adv_output = estimator(adv_data)
org_output = estimator(*data)
adv_output = estimator(*adv_data)
loss = criterion(org_output, target) + criterion(adv_output, target)
loss.backward()
optimizer.step()
Expand Down Expand Up @@ -156,27 +157,31 @@ def _parallel_fit_per_epoch(
return estimator, optimizer


def _get_fgsm_samples(sample, epsilon, sample_grad):
def _get_fgsm_samples(sample_list, epsilon, sample_grad_list):
"""
Private functions used to generate adversarial samples with fast gradient
sign method (FGSM).
"""

# Check the input range of `sample`
min_value, max_value = torch.min(sample), torch.max(sample)
if not 0 <= min_value < max_value <= 1:
msg = (
"The input range of samples passed to adversarial training"
" should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
" instead."
)
raise ValueError(msg.format(min_value, max_value))
perturbed_sample_list = []
for sample, sample_grad in zip(sample_list, sample_grad_list):
# Check the input range of `sample`
min_value, max_value = torch.min(sample), torch.max(sample)
if not 0 <= min_value < max_value <= 1:
msg = (
"The input range of samples passed to adversarial training"
" should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
" instead."
)
raise ValueError(msg.format(min_value, max_value))

sign_sample_grad = sample_grad.sign()
perturbed_sample = sample + epsilon * sign_sample_grad
perturbed_sample = torch.clamp(perturbed_sample, 0, 1)
sign_sample_grad = sample_grad.sign()
perturbed_sample = sample + epsilon * sign_sample_grad
perturbed_sample = torch.clamp(perturbed_sample, 0, 1)

return perturbed_sample
perturbed_sample_list.append(perturbed_sample)

return perturbed_sample_list


class _BaseAdversarialTraining(BaseModule):
Expand Down Expand Up @@ -218,10 +223,10 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
"""Implementation on the data forwarding in AdversarialTrainingClassifier.""", # noqa: E501
"classifier_forward",
)
def forward(self, x):
def forward(self, *x):
# Take the average over class distributions from all base estimators.
outputs = [
F.softmax(estimator(x), dim=1) for estimator in self.estimators_
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]
proba = op.average(outputs)

Expand Down Expand Up @@ -282,9 +287,9 @@ def fit(
best_acc = 0.0

# Internal helper function on pesudo forward
def _forward(estimators, data):
def _forward(estimators, *x):
outputs = [
F.softmax(estimator(data), dim=1) for estimator in estimators
F.softmax(estimator(*x), dim=1) for estimator in estimators
]
proba = op.average(outputs)

Expand Down Expand Up @@ -336,10 +341,11 @@ def _forward(estimators, data):
with torch.no_grad():
correct = 0
total = 0
for _, (data, target) in enumerate(test_loader):
data = data.to(self.device)
target = target.to(self.device)
output = _forward(estimators, data)
for _, elem in enumerate(test_loader):
data, target = io.split_data_target(
elem, self.device
)
output = _forward(estimators, *data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
Expand Down Expand Up @@ -384,8 +390,8 @@ def evaluate(self, test_loader, return_loss=False):
return super().evaluate(test_loader, return_loss)

@torchensemble_model_doc(item="predict")
def predict(self, X, return_numpy=True):
return super().predict(X, return_numpy)
def predict(self, *x):
return super().predict(*x)


@torchensemble_model_doc(
Expand All @@ -397,9 +403,9 @@ class AdversarialTrainingRegressor(_BaseAdversarialTraining, BaseRegressor):
"""Implementation on the data forwarding in AdversarialTrainingRegressor.""", # noqa: E501
"regressor_forward",
)
def forward(self, x):
def forward(self, *x):
# Take the average over predictions from all base estimators.
outputs = [estimator(x) for estimator in self.estimators_]
outputs = [estimator(*x) for estimator in self.estimators_]
pred = op.average(outputs)

return pred
Expand Down Expand Up @@ -459,8 +465,8 @@ def fit(
best_mse = float("inf")

# Internal helper function on pesudo forward
def _forward(estimators, data):
outputs = [estimator(data) for estimator in estimators]
def _forward(estimators, *x):
outputs = [estimator(*x) for estimator in estimators]
pred = op.average(outputs)

return pred
Expand Down Expand Up @@ -510,10 +516,11 @@ def _forward(estimators, data):
self.eval()
with torch.no_grad():
mse = 0.0
for _, (data, target) in enumerate(test_loader):
data = data.to(self.device)
target = target.to(self.device)
output = _forward(estimators, data)
for _, elem in enumerate(test_loader):
data, target = io.split_data_target(
elem, self.device
)
output = _forward(estimators, *data)
mse += criterion(output, target)
mse /= len(test_loader)

Expand Down Expand Up @@ -553,5 +560,5 @@ def evaluate(self, test_loader):
return super().evaluate(test_loader)

@torchensemble_model_doc(item="predict")
def predict(self, X, return_numpy=True):
return super().predict(X, return_numpy)
def predict(self, *x):
return super().predict(*x)
Loading