Skip to content

Commit 39d6bc9

Browse files
authored
Configure pylint and fix linter violations (#179)
See https://pylint.readthedocs.io/en/stable/ This PR also updates the project to adopt [GitHub's "Scripts to Rule Them All" convention](https://github.com/github/scripts-to-rule-them-all). --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 153c4fe commit 39d6bc9

18 files changed

+182
-53
lines changed

.github/workflows/ci.yaml

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@ jobs:
2828
with:
2929
python-version: ${{ matrix.python-version }}
3030
cache: "pip"
31-
- name: Install dependencies
32-
run: |
33-
python -m pip install -r requirements.txt -r requirements-dev.txt .
34-
yes | python -m mypy --install-types replicate || true
3531

36-
- name: Lint
37-
run: |
38-
python -m mypy replicate
39-
python -m ruff .
40-
python -m ruff format --check .
32+
- name: Setup
33+
run: ./script/setup
34+
4135
- name: Test
42-
run: python -m pytest
36+
run: ./script/test
37+
38+
- name: Lint
39+
run: ./script/lint
40+

pyproject.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ requires-python = ">=3.8"
1313
dependencies = ["packaging", "pydantic>1", "httpx>=0.21.0,<1"]
1414
optional-dependencies = { dev = [
1515
"mypy",
16+
"pylint",
1617
"pytest",
1718
"pytest-asyncio",
1819
"pytest-recording",
@@ -27,13 +28,27 @@ repository = "https://github.com/replicate/replicate-python"
2728
[tool.pytest.ini_options]
2829
testpaths = "tests/"
2930

31+
[tool.setuptools]
32+
# See https://github.com/pypa/setuptools/issues/3197#issuecomment-1078770109
33+
py-modules = []
34+
3035
[tool.setuptools.package-data]
3136
"replicate" = ["py.typed"]
3237

3338
[tool.mypy]
3439
plugins = "pydantic.mypy"
3540
exclude = ["tests/"]
3641

42+
[tool.pylint.main]
43+
disable = [
44+
"C0301", # Line too long
45+
"C0413", # Import should be placed at the top of the module
46+
"C0114", # Missing module docstring
47+
"R0801", # Similar lines in N files
48+
"W0212", # Access to a protected member
49+
"W0622", # Redefining built-in
50+
]
51+
3752
[tool.ruff]
3853
select = [
3954
"E", # pycodestyle error

replicate/base_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def reload(self) -> None:
2424
"""
2525
Load this object from the server again.
2626
"""
27-
new_model = self._collection.get(self.id)
28-
for k, v in new_model.dict().items():
27+
28+
new_model = self._collection.get(self.id) # pylint: disable=no-member
29+
for k, v in new_model.dict().items(): # pylint: disable=invalid-name
2930
setattr(self, k, v)

replicate/client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,30 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
8484

8585
@property
8686
def models(self) -> ModelCollection:
87+
"""
88+
Namespace for operations related to models.
89+
"""
8790
return ModelCollection(client=self)
8891

8992
@property
9093
def predictions(self) -> PredictionCollection:
94+
"""
95+
Namespace for operations related to predictions.
96+
"""
9197
return PredictionCollection(client=self)
9298

9399
@property
94100
def trainings(self) -> TrainingCollection:
101+
"""
102+
Namespace for operations related to trainings.
103+
"""
95104
return TrainingCollection(client=self)
96105

97106
@property
98107
def deployments(self) -> DeploymentCollection:
108+
"""
109+
Namespace for operations related to deployments.
110+
"""
99111
return DeploymentCollection(client=self)
100112

101113
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401

replicate/collection.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from replicate.client import Client
66

77
from replicate.base_model import BaseModel
8+
from replicate.exceptions import ReplicateException
89

910
Model = TypeVar("Model", bound=BaseModel)
1011

@@ -17,20 +18,21 @@ class Collection(abc.ABC, Generic[Model]):
1718
def __init__(self, client: "Client") -> None:
1819
self._client = client
1920

20-
@abc.abstractproperty
21-
def model(self) -> Model:
21+
@property
22+
@abc.abstractmethod
23+
def model(self) -> Model: # pylint: disable=missing-function-docstring
2224
pass
2325

2426
@abc.abstractmethod
25-
def list(self) -> List[Model]:
27+
def list(self) -> List[Model]: # pylint: disable=missing-function-docstring
2628
pass
2729

2830
@abc.abstractmethod
29-
def get(self, key: str) -> Model:
31+
def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
3032
pass
3133

3234
@abc.abstractmethod
33-
def create(self, **kwargs) -> Model:
35+
def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring
3436
pass
3537

3638
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
@@ -41,13 +43,12 @@ def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
4143
attrs._client = self._client
4244
attrs._collection = self
4345
return cast(Model, attrs)
44-
elif (
45-
isinstance(attrs, dict) and self.model is not None and callable(self.model)
46-
):
46+
47+
if isinstance(attrs, dict) and self.model is not None and callable(self.model):
4748
model = self.model(**attrs)
4849
model._client = self._client
4950
model._collection = self
5051
return model
51-
else:
52-
name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
53-
raise Exception(f"Can't create {name} from {attrs}")
52+
53+
name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
54+
raise ReplicateException(f"Can't create {name} from {attrs}")

replicate/deployment.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,19 @@ def predictions(self) -> "DeploymentPredictionCollection":
3535

3636

3737
class DeploymentCollection(Collection):
38+
"""
39+
Namespace for operations related to deployments.
40+
"""
41+
3842
model = Deployment
3943

4044
def list(self) -> List[Deployment]:
45+
"""
46+
List deployments.
47+
48+
Raises:
49+
NotImplementedError: This method is not implemented.
50+
"""
4151
raise NotImplementedError()
4252

4353
def get(self, name: str) -> Deployment:
@@ -56,6 +66,12 @@ def get(self, name: str) -> Deployment:
5666
return self.prepare_model({"username": username, "name": name})
5767

5868
def create(self, **kwargs) -> Deployment:
69+
"""
70+
Create a deployment.
71+
72+
Raises:
73+
NotImplementedError: This method is not implemented.
74+
"""
5975
raise NotImplementedError()
6076

6177
def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
@@ -74,6 +90,12 @@ def __init__(self, client: "Client", deployment: Deployment) -> None:
7490
self._deployment = deployment
7591

7692
def list(self) -> List[Prediction]:
93+
"""
94+
List predictions in a deployment.
95+
96+
Raises:
97+
NotImplementedError: This method is not implemented.
98+
"""
7799
raise NotImplementedError()
78100

79101
def get(self, id: str) -> Prediction:

replicate/files.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,34 @@
77
import httpx
88

99

10-
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
10+
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
1111
"""
1212
Upload a file to the server.
1313
1414
Args:
15-
fh: A file handle to upload.
15+
file: A file handle to upload.
1616
output_file_prefix: A string to prepend to the output file name.
1717
Returns:
1818
str: A URL to the uploaded file.
1919
"""
2020
# Lifted straight from cog.files
2121

22-
fh.seek(0)
22+
file.seek(0)
2323

2424
if output_file_prefix is not None:
25-
name = getattr(fh, "name", "output")
25+
name = getattr(file, "name", "output")
2626
url = output_file_prefix + os.path.basename(name)
27-
resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore
27+
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
2828
resp.raise_for_status()
29+
2930
return url
3031

31-
b = fh.read()
32-
# The file handle is strings, not bytes
33-
if isinstance(b, str):
34-
b = b.encode("utf-8")
35-
encoded_body = base64.b64encode(b)
36-
if getattr(fh, "name", None):
37-
# despite doing a getattr check here, mypy complains that io.IOBase has no attribute name
38-
mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore
39-
else:
40-
mime_type = "application/octet-stream"
41-
s = encoded_body.decode("utf-8")
42-
return f"data:{mime_type};base64,{s}"
32+
body = file.read()
33+
# Ensure the file handle is in bytes
34+
body = body.encode("utf-8") if isinstance(body, str) else body
35+
encoded_body = base64.b64encode(body).decode("utf-8")
36+
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
37+
mime_type = (
38+
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
39+
)
40+
return f"data:{mime_type};base64,{encoded_body}"

replicate/json.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
try:
77
import numpy as np # type: ignore
88

9-
has_numpy = True
9+
HAS_NUMPY = True
1010
except ImportError:
11-
has_numpy = False
11+
HAS_NUMPY = False
1212

1313

14+
# pylint: disable=too-many-return-statements
1415
def encode_json(
1516
obj: Any, # noqa: ANN401
1617
upload_file: Callable[[io.IOBase], str],
@@ -25,11 +26,11 @@ def encode_json(
2526
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
2627
return [encode_json(value, upload_file) for value in obj]
2728
if isinstance(obj, Path):
28-
with obj.open("rb") as f:
29-
return upload_file(f)
29+
with obj.open("rb") as file:
30+
return upload_file(file)
3031
if isinstance(obj, io.IOBase):
3132
return upload_file(obj)
32-
if has_numpy:
33+
if HAS_NUMPY:
3334
if isinstance(obj, np.integer):
3435
return int(obj)
3536
if isinstance(obj, np.floating):

replicate/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def versions(self) -> VersionCollection:
107107

108108

109109
class ModelCollection(Collection):
110+
"""
111+
Namespace for operations related to models.
112+
"""
113+
110114
model = Model
111115

112116
def list(self) -> List[Model]:
@@ -136,6 +140,12 @@ def get(self, key: str) -> Model:
136140
return self.prepare_model(resp.json())
137141

138142
def create(self, **kwargs) -> Model:
143+
"""
144+
Create a model.
145+
146+
Raises:
147+
NotImplementedError: This method is not implemented.
148+
"""
139149
raise NotImplementedError()
140150

141151
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:

replicate/prediction.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def wait(self) -> None:
103103
Wait for prediction to finish.
104104
"""
105105
while self.status not in ["succeeded", "failed", "canceled"]:
106-
time.sleep(self._client.poll_interval)
106+
time.sleep(self._client.poll_interval) # pylint: disable=no-member
107107
self.reload()
108108

109109
def output_iterator(self) -> Iterator[Any]:
@@ -114,7 +114,7 @@ def output_iterator(self) -> Iterator[Any]:
114114
new_output = output[len(previous_output) :]
115115
yield from new_output
116116
previous_output = output
117-
time.sleep(self._client.poll_interval)
117+
time.sleep(self._client.poll_interval) # pylint: disable=no-member
118118
self.reload()
119119

120120
if self.status == "failed":
@@ -129,10 +129,14 @@ def cancel(self) -> None:
129129
"""
130130
Cancels a running prediction.
131131
"""
132-
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")
132+
self._client._request("POST", f"/v1/predictions/{self.id}/cancel") # pylint: disable=no-member
133133

134134

135135
class PredictionCollection(Collection):
136+
"""
137+
Namespace for operations related to predictions.
138+
"""
139+
136140
model = Prediction
137141

138142
def list(self) -> List[Prediction]:

replicate/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]:
1515

1616
def make_schema_backwards_compatible(
1717
schema: dict,
18-
version: str,
18+
cog_version: str,
1919
) -> dict:
2020
"""A place to add backwards compatibility logic for our openapi schema"""
2121

2222
# If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
23-
if version_has_no_array_type(version):
23+
if version_has_no_array_type(cog_version):
2424
output = schema["components"]["schemas"]["Output"]
2525
if output.get("type") == "array":
2626
output["x-cog-array-type"] = "iterator"

replicate/training.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,14 @@ class Training(BaseModel):
5858

5959
def cancel(self) -> None:
6060
"""Cancel a running training"""
61-
self._client._request("POST", f"/v1/trainings/{self.id}/cancel")
61+
self._client._request("POST", f"/v1/trainings/{self.id}/cancel") # pylint: disable=no-member
6262

6363

6464
class TrainingCollection(Collection):
65+
"""
66+
Namespace for operations related to trainings.
67+
"""
68+
6569
model = Training
6670

6771
def list(self) -> List[Training]:

0 commit comments

Comments
 (0)