Commit abbfb05 (merge, 2 parents: 87023eb + a9d9f33)

Merge remote-tracking branch 'upstream/release/1.2-dev' into refactor/checkpoint

228 files changed: +2955 / -1441 lines


.github/workflows/ci_dockers.yml (2 additions, 2 deletions)

@@ -15,7 +15,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: [3.6]
-        pytorch_version: [1.3, 1.7]
+        pytorch_version: [1.4, 1.7]
     steps:
     - name: Checkout
       uses: actions/checkout@v2
@@ -74,7 +74,7 @@ jobs:
           - python_version: 3.7
             pytorch_version: 1.6
           - python_version: 3.6
-            pytorch_version: 1.3
+            pytorch_version: 1.4
     steps:
     - name: Checkout
       uses: actions/checkout@v2

.github/workflows/ci_test-conda.yml (1 addition, 1 deletion)

@@ -16,7 +16,7 @@ jobs:
       matrix:
         # os: [ubuntu-20.04]
         python-version: [3.7]
-        pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8]
+        pytorch-version: [1.4, 1.5, 1.6, 1.7, 1.8]

     # Timeout: https://stackoverflow.com/a/59076067/4521646
     timeout-minutes: 35

.github/workflows/ci_test-full.yml (0 additions, 17 deletions)

@@ -52,23 +52,6 @@ jobs:
         open(fname, 'w').writelines(lines)
       shell: python

-    # versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996)
-    - name: Adjust minimal for Python 3.8 and MacOS
-      if: matrix.requires == 'minimal' && (runner.os == 'macOS' || matrix.python-version == 3.8)
-      run : |
-        fname = 'requirements.txt'
-        req = open(fname).read().replace('torch>=1.3', 'torch>=1.4')
-        open(fname, 'w').write(req)
-
-        fname = 'requirements/examples.txt'
-        req = open(fname).read().replace('torchvision>=0.4.1', 'torchvision>=0.5.0')
-        open(fname, 'w').write(req)
-
-        fname = 'requirements/extra.txt'
-        req = open(fname).read().replace('torchtext>=0.3.1', 'torchtext>=0.5.0')
-        open(fname, 'w').write(req)
-      shell: python
-
     - name: Set min. dependencies
       if: matrix.requires == 'minimal'
       run: |

.github/workflows/nightly.yml (1 addition, 5 deletions)

@@ -85,11 +85,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: [3.6, 3.7, 3.8]
-        pytorch_version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8]
-        exclude:
-          # excludes PT 1.3 as it is missing on pypi
-          - python_version: 3.8
-            pytorch_version: 1.3
+        pytorch_version: [1.4, 1.5, 1.6, 1.7, 1.8]

     steps:
     - name: Checkout
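The removed `exclude` blocks worked by subtracting specific combinations from the cross-product of the matrix axes; once PyTorch 1.3 left the matrix, the exclusion became dead weight. The semantics can be sketched in Python (illustrative only, not how Actions is implemented; the version lists mirror the old workflow):

```python
from itertools import product

python_versions = ["3.6", "3.7", "3.8"]
pytorch_versions = ["1.3", "1.4", "1.5", "1.6", "1.7", "1.8"]
# PT 1.3 wheels were missing for py3.8 on PyPI, hence the exclusion
exclude = [("3.8", "1.3")]

# a matrix is the cross-product of its axes, minus any excluded pairs
jobs = [(py, pt) for py, pt in product(python_versions, pytorch_versions)
        if (py, pt) not in exclude]
print(len(jobs))  # 17 combinations: 3 * 6 minus the single exclusion
```

Dropping 1.3 from `pytorch_versions` makes `exclude` a no-op, which is why the commit deletes it outright.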

.github/workflows/release-docker.yml (1 addition, 5 deletions)

@@ -14,11 +14,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: [3.6, 3.7, 3.8]
-        pytorch_version: [1.3, 1.4, 1.5, 1.6, 1.7]
-        exclude:
-          # excludes PT 1.3 as it is missing on pypi
-          - python_version: 3.8
-            pytorch_version: 1.3
+        pytorch_version: [1.4, 1.5, 1.6, 1.7]
     steps:
     - name: Checkout
       uses: actions/checkout@v2

CHANGELOG.md (23 additions, 0 deletions)

@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))

+
+- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347))
+
+
 - Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377))


@@ -47,6 +51,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


+- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))
+
+
+- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
+
+
+
 ### Changed

 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
@@ -61,6 +72,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


+- Metric `compute()` method will no longer automatically call `reset()` ([#5409](https://github.com/PyTorchLightning/pytorch-lightning/pull/5409/))
+
+
+- Set PyTorch 1.4 as min requirements, also for testing and examples `torchvision>=0.5` and `torchtext>=0.5` ([#5418](https://github.com/PyTorchLightning/pytorch-lightning/pull/5418))
+
+
+- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446))
+
+
+- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
+
+
 ### Deprecated

 - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
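The Recall@K generalization introduced in #4842 can be illustrated with a plain-Python sketch of the underlying idea (a hypothetical helper, not the library implementation): with one relevant class per sample, a sample counts as a hit when its true class appears among the k highest-scoring classes.

```python
def recall_at_k(scores, targets, k):
    """Fraction of samples whose true class is among the top-k scores.

    scores:  list of per-class score lists, one inner list per sample
    targets: list of true class indices, one per sample
    """
    hits = 0
    for sample_scores, target in zip(scores, targets):
        # indices of the k highest-scoring classes for this sample
        top_k = sorted(range(len(sample_scores)),
                       key=lambda i: sample_scores[i], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)

scores = [[0.1, 0.6, 0.3],   # top-1 class: 1
          [0.5, 0.2, 0.3]]   # top-1 class: 0
targets = [2, 0]
print(recall_at_k(scores, targets, 1))  # 0.5 -- only the second sample hits
print(recall_at_k(scores, targets, 2))  # 1.0 -- class 2 is in sample 0's top-2
```

In the library the same effect is obtained by passing `top_k` to `Recall`/`Precision` (or `recall`/`precision`); this sketch only shows what varying k does to the count of hits.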

README.md (9 additions, 10 deletions)

@@ -81,16 +81,15 @@ Lightning can automatically export to ONNX or TorchScript for those cases.
 ## Continuous Integration
 <center>

-| System / PyTorch ver. | 1.3 (min. req.)* | 1.4 | 1.5 | 1.6 | 1.7 (latest) | 1.8 (nightly) |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| Conda py3.7 [linux] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
-| Linux py3.7 [GPUs**] | - | - | - | [![GPUs Status](http://104.154.220.231/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://104.154.220.231/PyTorchLightning/pytorch-lightning) | - | - |
-| Linux py3.{6,7} [TPUs***] | - | - | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - |
-| Linux py3.{6,7} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-| OSX py3.{6,7,8} | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-| Windows py3.{6,7,8} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
-
-- _\* `torch>=1.4` is the minimal pytorch version for Python 3.8_
+| System / PyTorch ver. | 1.4 (min. req.)* | 1.5 | 1.6 | 1.7 (latest) | 1.8 (nightly) |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Conda py3.7 [linux] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
+| Linux py3.7 [GPUs**] | - | - | [![GPUs Status](http://104.154.220.231/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://104.154.220.231/PyTorchLightning/pytorch-lightning) | - | - |
+| Linux py3.{6,7} [TPUs***] | - | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) |
+| Linux py3.{6,7} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| OSX py3.{6,7,8} | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
+| Windows py3.{6,7,8} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |

 - _\** tests run on two NVIDIA K80_
 - _\*** tests run on Google GKE TPUv2/3_
 - _TPU w/ py3.6/py3.7 means we support Colab and Kaggle env._

dockers/base-conda/Dockerfile (0 additions, 1 deletion)

@@ -17,7 +17,6 @@
 # --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.6 --build-arg PYTORCH_CHANNEL=pytorch
 # --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.5 --build-arg PYTORCH_CHANNEL=pytorch
 # --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.4 --build-arg PYTORCH_CHANNEL=pytorch
-# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.3 --build-arg PYTORCH_CHANNEL=pytorch

 ARG CUDNN_VERSION=8
 ARG CUDA_VERSION=10.2

dockers/base-cuda/Dockerfile (0 additions, 1 deletion)

@@ -17,7 +17,6 @@
 # --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.6 --build-arg CUDA_VERSION=10.2
 # --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.5 --build-arg CUDA_VERSION=10.2
 # --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.4 --build-arg CUDA_VERSION=10.1
-# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.3 --build-arg CUDA_VERSION=10.1

 ARG CUDNN_VERSION=8
 ARG CUDA_VERSION=10.2

docs/source/callbacks.rst (1 addition, 0 deletions)

@@ -98,6 +98,7 @@ Lightning has a few built-in callbacks.
     EarlyStopping
     GPUStatsMonitor
     GradientAccumulationScheduler
+    LambdaCallback
     LearningRateMonitor
     ModelCheckpoint
     ProgressBar

docs/source/metrics.rst (10 additions, 5 deletions)

@@ -20,6 +20,11 @@ The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user.
 serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
 provided input.

+.. warning::
+    From v1.2 onward ``compute()`` will no longer automatically call ``reset()``,
+    and it is up to the user to reset metrics between epochs, except in the case where the
+    metric is directly passed to ``LightningModule``s ``self.log``.
+
 These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
 distributed mode, the internal state of each metric is synced and reduced across each process, so that the
 logic present in ``.compute()`` is applied to state information from all processes.
@@ -377,8 +382,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different inputs.
     ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])


-Using the ``is_multiclass`` parameter
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Using the is_multiclass parameter
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
 but are actually binary/multi-label - for example, if both predictions and targets are
@@ -597,14 +602,14 @@ roc [func]
 precision [func]
 ~~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.precision
+.. autofunction:: pytorch_lightning.metrics.functional.precision
     :noindex:


 precision_recall [func]
 ~~~~~~~~~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall
+.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
     :noindex:


@@ -618,7 +623,7 @@ precision_recall_curve [func]
 recall [func]
 ~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
+.. autofunction:: pytorch_lightning.metrics.functional.recall
     :noindex:

 select_topk [func]
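The ``update()``/``compute()``/``reset()`` lifecycle that this file documents (and the new v1.2 behavior where ``compute()`` no longer resets state) can be sketched in plain Python; this is an illustrative stand-in for a mean-aggregating metric, not the actual `Metric` base class:

```python
class MeanMetric:
    """Minimal sketch of the update()/compute()/reset() lifecycle.

    Mirrors the documented contract: compute() returns the aggregated
    value but, as of v1.2, does NOT reset internal state -- the caller
    is responsible for resetting between epochs.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        # clear accumulated state
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # accumulate state from one batch
        self.total += value
        self.count += 1

    def compute(self):
        # aggregate over all batches seen since the last reset
        return self.total / self.count

metric = MeanMetric()
for batch_value in [1.0, 2.0, 3.0]:
    metric.update(batch_value)
print(metric.compute())  # 2.0
metric.reset()           # manual reset before the next epoch
```

Without the manual `reset()`, state from the first epoch would leak into the second epoch's aggregate, which is exactly the pitfall the warning above calls out.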

docs/source/tpu.rst (25 additions, 4 deletions)

@@ -40,7 +40,7 @@ To access TPUs, there are three main ways.
 ----------------

 Colab TPUs
------------
+----------
 Colab is like a jupyter notebook with a free GPU or TPU
 hosted on GCP.

@@ -129,8 +129,7 @@ That's it! Your model will train on all 8 TPU cores.
 ----------------

 TPU core training
-
-------------------------
+-----------------

 Lightning supports training on a single TPU core or 8 TPU cores.

@@ -177,7 +176,7 @@ on how to set up the instance groups and VMs needed to run TPU Pods.
 ----------------

 16 bit precision
------------------
+----------------
 Lightning also supports training in 16-bit precision with TPUs.
 By default, TPU training will use 32-bit precision. To enable 16-bit,
 set the 16-bit flag.
@@ -194,6 +193,28 @@ Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_

 ----------------

+Performance considerations
+--------------------------
+
+The TPU was designed for specific workloads and operations to carry out large volumes of matrix multiplication,
+convolution operations and other commonly used ops in applied deep learning.
+The specialization makes it a strong choice for NLP tasks, sequential convolutional networks, and under low precision operation.
+There are cases in which training on TPUs is slower when compared with GPUs, for possible reasons listed:
+
+- Too small batch size.
+- Explicit evaluation of tensors during training, e.g. ``tensor.item()``
+- Tensor shapes (e.g. model inputs) change often during training.
+- Limited resources when using TPU's with PyTorch `Link <https://github.com/pytorch/xla/issues/2054#issuecomment-627367729>`_
+- XLA Graph compilation during the initial steps `Reference <https://github.com/pytorch/xla/issues/2383#issuecomment-666519998>`_
+- Some tensor ops are not fully supported on TPU, or not supported at all. These operations will be performed on CPU (context switch).
+- PyTorch integration is still experimental. Some performance bottlenecks may simply be the result of unfinished implementation.
+
+The official PyTorch XLA `performance guide <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#known-performance-caveats>`_
+has more detailed information on how PyTorch code can be optimized for TPU. In particular, the
+`metrics report <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#get-a-metrics-report>`_ allows
+one to identify operations that lead to context switching.
+
 About XLA
 ----------
 XLA is the library that interfaces PyTorch with the TPUs.
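The "tensor shapes change often" caveat in the new performance section stems from XLA recompiling its computation graph whenever it sees a new shape; a common mitigation is to pad variable-length inputs to a fixed size. A minimal sketch in plain Python (hypothetical helper name, shown on lists rather than tensors):

```python
def pad_to_length(seq, length, pad_value=0):
    """Pad (or truncate) a sequence to a fixed length so that every
    batch presents the same shape to the XLA compiler, avoiding
    repeated graph recompilation."""
    if len(seq) >= length:
        return seq[:length]                     # truncate overly long inputs
    return seq + [pad_value] * (length - len(seq))

batches = [[5, 3], [7, 1, 2, 9, 4, 8], [6]]
fixed = [pad_to_length(b, 4) for b in batches]
print(fixed)  # [[5, 3, 0, 0], [7, 1, 2, 9], [6, 0, 0, 0]]
```

The same idea applies to real tensor pipelines (e.g. padding token sequences to a bucketed maximum length) so that step shapes stay constant across the epoch.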

environment.yml (3 additions, 3 deletions)

@@ -26,7 +26,7 @@ dependencies:
   - python>=3.6
   - pip>20.1
   - numpy>=1.16.4
-  - pytorch>=1.3,<1.8
+  - pytorch>=1.4
   - future>=0.17.1
   - PyYAML>=5.1
   - tqdm>=4.41.0
@@ -38,10 +38,10 @@ dependencies:
   - scikit-learn>=0.20.0
   - matplotlib>=3.1.1
   - omegaconf>=2.0.0
-  - torchtext>=0.3.1
+  - torchtext>=0.5

   # Examples
-  - torchvision>=0.4.1,<0.9.0
+  - torchvision>=0.5

   - pip:
     - test-tube>=0.7.5

pyproject.toml (7 additions, 22 deletions)

@@ -23,31 +23,16 @@ known_first_party = [
     "tests",
 ]
 skip_glob = [
-    "pytorch_lightning/accelerators/*",
-    "pytorch_lightning/callbacks/*",
-    "pytorch_lightning/cluster_environments/*",
+    # todo
     "pytorch_lightning/core/*",
+
+
+    # todo
     "pytorch_lightning/distributed/*",
-    "pytorch_lightning/loggers/*",
-    "pytorch_lightning/metrics/*",
-    "pytorch_lightning/overrides/*",
+
+
+    # todo
     "pytorch_lightning/plugins/*",
-    "pytorch_lightning/profiler/*",
-    "pytorch_lightning/trainer/*",
-    "pytorch_lightning/tuner/*",
-    "pytorch_lightning/utilities/*",
-    "tests/backends/*",
-    "tests/base/*",
-    "tests/callbacks/*",
-    "tests/checkpointing/*",
-    "tests/core/*",
-    "tests/loggers/*",
-    "tests/metrics/*",
-    "tests/models/*",
-    "tests/plugins/*",
-    "tests/trainer/*",
-    "tests/tuner/*",
-    "tests/utilities/*",
 ]
 profile = "black"
 line_length = 120

0 comments