
Commit b8dd97b

Add additional ruff suggestions (#1062)
* Enabled ruff rule PT001 and ANN204
* Enabled ruff rule B008
* Enabled ruff rule EM101
* Enabled ruff rule PLR1714
* Enabled ruff rule ANN201
* Enabled ruff rule C400
* Enabled ruff rule B904
* Enabled ruff rule UP006
* Enabled ruff rule RUF012
* Enabled ruff rule FBT003
* Enabled ruff rule C416
* Enabled ruff rule SIM102
* Enabled ruff rule PGH003
* Enabled ruff rule PERF401
* Enabled ruff rule EM102
* Enabled ruff rule SIM108
* Enabled ruff rule ICN001
* Enabled ruff rule ICN001
* implemented reviews
* Update pyproject.toml to ignore `SIM102`
* Enabled ruff rule PLW2901
* Enabled ruff rule RET503
* Fixed failing ruff tests
1 parent 7c1c08f commit b8dd97b

Showing 25 changed files with 213 additions and 247 deletions.
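
For context, the sketch below illustrates the style these rules push toward. It is not code from this commit, and the function is hypothetical, but each commented line mirrors a pattern changed here (ICN001 import alias, ANN201/ANN204 return annotations, EM101 no string literal directly in a raise, SIM108 ternary over if/else, PERF401/C400 comprehension over append-in-loop):

import pyarrow as pa  # ICN001: import pyarrow under its conventional alias "pa"


def double_positive(values: list[float]) -> pa.Array:  # ANN201: annotate the return type
    if not values:
        msg = "values must be non-empty"  # EM101: bind the message before raising
        raise ValueError(msg)
    # SIM108: a conditional expression instead of a four-line if/else
    scale = 2.0 if len(values) > 1 else 1.0
    # PERF401/C400: build the list with a comprehension rather than append in a loop
    return pa.array([v * scale for v in values if v > 0])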

benchmarks/db-benchmark/groupby-datafusion.py (+12 −12)

@@ -20,7 +20,7 @@
 import timeit
 
 import datafusion as df
-import pyarrow
+import pyarrow as pa
 from datafusion import (
     RuntimeEnvBuilder,
     SessionConfig,
@@ -37,7 +37,7 @@
 exec(open("./_helpers/helpers.py").read())
 
 
-def ans_shape(batches):
+def ans_shape(batches) -> tuple[int, int]:
     rows, cols = 0, 0
     for batch in batches:
         rows += batch.num_rows
@@ -48,7 +48,7 @@ def ans_shape(batches):
     return rows, cols
 
 
-def execute(df):
+def execute(df) -> list:
     print(df.execution_plan().display_indent())
     return df.collect()
 
@@ -68,14 +68,14 @@ def execute(df):
 src_grp = os.path.join("data", data_name + ".csv")
 print("loading dataset %s" % src_grp, flush=True)
 
-schema = pyarrow.schema(
+schema = pa.schema(
     [
-        ("id4", pyarrow.int32()),
-        ("id5", pyarrow.int32()),
-        ("id6", pyarrow.int32()),
-        ("v1", pyarrow.int32()),
-        ("v2", pyarrow.int32()),
-        ("v3", pyarrow.float64()),
+        ("id4", pa.int32()),
+        ("id5", pa.int32()),
+        ("id6", pa.int32()),
+        ("v1", pa.int32()),
+        ("v2", pa.int32()),
+        ("v3", pa.float64()),
     ]
 )
 
@@ -93,8 +93,8 @@ def execute(df):
 )
 config = (
     SessionConfig()
-    .with_repartition_joins(False)
-    .with_repartition_aggregations(False)
+    .with_repartition_joins(enabled=False)
+    .with_repartition_aggregations(enabled=False)
     .set("datafusion.execution.coalesce_batches", "false")
 )
 ctx = SessionContext(config, runtime)

benchmarks/db-benchmark/join-datafusion.py (+3 −2)

@@ -29,7 +29,7 @@
 exec(open("./_helpers/helpers.py").read())
 
 
-def ans_shape(batches):
+def ans_shape(batches) -> tuple[int, int]:
     rows, cols = 0, 0
     for batch in batches:
         rows += batch.num_rows
@@ -57,7 +57,8 @@ def ans_shape(batches):
     os.path.join("data", y_data_name[2] + ".csv"),
 ]
 if len(src_jn_y) != 3:
-    raise Exception("Something went wrong in preparing files used for join")
+    error_msg = "Something went wrong in preparing files used for join"
+    raise Exception(error_msg)
 
 print(
     "loading datasets "

benchmarks/tpch/tpch.py (+2 −5)

@@ -21,7 +21,7 @@
 from datafusion import SessionContext
 
 
-def bench(data_path, query_path):
+def bench(data_path, query_path) -> None:
     with open("results.csv", "w") as results:
         # register tables
         start = time.time()
@@ -68,10 +68,7 @@ def bench(data_path, query_path):
         with open(f"{query_path}/q{query}.sql") as f:
             text = f.read()
         tmp = text.split(";")
-        queries = []
-        for str in tmp:
-            if len(str.strip()) > 0:
-                queries.append(str.strip())
+        queries = [s.strip() for s in tmp if len(s.strip()) > 0]
 
         try:
             start = time.time()

dev/release/generate-changelog.py (+3 −3)

@@ -24,7 +24,7 @@
 from github import Github
 
 
-def print_pulls(repo_name, title, pulls):
+def print_pulls(repo_name, title, pulls) -> None:
     if len(pulls) > 0:
         print(f"**{title}:**")
         print()
@@ -34,7 +34,7 @@ def print_pulls(repo_name, title, pulls):
     print()
 
 
-def generate_changelog(repo, repo_name, tag1, tag2, version):
+def generate_changelog(repo, repo_name, tag1, tag2, version) -> None:
     # get a list of commits between two tags
     print(f"Fetching list of commits between {tag1} and {tag2}", file=sys.stderr)
     comparison = repo.compare(tag1, tag2)
@@ -154,7 +154,7 @@ def generate_changelog(repo, repo_name, tag1, tag2, version):
     )
 
 
-def cli(args=None):
+def cli(args=None) -> None:
     """Process command line arguments."""
     if not args:
         args = sys.argv[1:]

docs/source/conf.py (+2 −2)

@@ -73,7 +73,7 @@
 autoapi_python_class_content = "both"
 
 
-def autoapi_skip_member_fn(app, what, name, obj, skip, options):  # noqa: ARG001
+def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool:  # noqa: ARG001
     skip_contents = [
         # Re-exports
         ("class", "datafusion.DataFrame"),
@@ -93,7 +93,7 @@ def autoapi_skip_member_fn(app, what, name, obj, skip, options):  # noqa: ARG001
     return skip
 
 
-def setup(sphinx):
+def setup(sphinx) -> None:
     sphinx.connect("autoapi-skip-member", autoapi_skip_member_fn)
 
 

examples/create-context.py (+6 −6)

@@ -25,14 +25,14 @@
 runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000)
 config = (
     SessionConfig()
-    .with_create_default_catalog_and_schema(True)
+    .with_create_default_catalog_and_schema(enabled=True)
     .with_default_catalog_and_schema("foo", "bar")
     .with_target_partitions(8)
-    .with_information_schema(True)
-    .with_repartition_joins(False)
-    .with_repartition_aggregations(False)
-    .with_repartition_windows(False)
-    .with_parquet_pruning(False)
+    .with_information_schema(enabled=True)
+    .with_repartition_joins(enabled=False)
+    .with_repartition_aggregations(enabled=False)
+    .with_repartition_windows(enabled=False)
+    .with_parquet_pruning(enabled=False)
     .set("datafusion.execution.parquet.pushdown_filters", "true")
 )
 ctx = SessionContext(config, runtime)

examples/python-udaf.py (+16 −20)

@@ -16,7 +16,7 @@
 # under the License.
 
 import datafusion
-import pyarrow
+import pyarrow as pa
 import pyarrow.compute
 from datafusion import Accumulator, col, udaf
 
@@ -26,48 +26,44 @@ class MyAccumulator(Accumulator):
     Interface of a user-defined accumulation.
     """
 
-    def __init__(self):
-        self._sum = pyarrow.scalar(0.0)
+    def __init__(self) -> None:
+        self._sum = pa.scalar(0.0)
 
-    def update(self, values: pyarrow.Array) -> None:
+    def update(self, values: pa.Array) -> None:
         # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(values).as_py()
-        )
+        self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py())
 
-    def merge(self, states: pyarrow.Array) -> None:
+    def merge(self, states: pa.Array) -> None:
         # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(states).as_py()
-        )
+        self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states).as_py())
 
-    def state(self) -> pyarrow.Array:
-        return pyarrow.array([self._sum.as_py()])
+    def state(self) -> pa.Array:
+        return pa.array([self._sum.as_py()])
 
-    def evaluate(self) -> pyarrow.Scalar:
+    def evaluate(self) -> pa.Scalar:
         return self._sum
 
 
 # create a context
 ctx = datafusion.SessionContext()
 
 # create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
     names=["a", "b"],
 )
 df = ctx.create_dataframe([[batch]])
 
 my_udaf = udaf(
     MyAccumulator,
-    pyarrow.float64(),
-    pyarrow.float64(),
-    [pyarrow.float64()],
+    pa.float64(),
+    pa.float64(),
+    [pa.float64()],
     "stable",
 )
 
 df = df.aggregate([], [my_udaf(col("a"))])
 
 result = df.collect()[0]
 
-assert result.column(0) == pyarrow.array([6.0])
+assert result.column(0) == pa.array([6.0])

examples/python-udf-comparisons.py (+3 −6)

@@ -112,8 +112,8 @@ def is_of_interest_impl(
     returnflag_arr: pa.Array,
 ) -> pa.Array:
     result = []
-    for idx, partkey in enumerate(partkey_arr):
-        partkey = partkey.as_py()
+    for idx, partkey_val in enumerate(partkey_arr):
+        partkey = partkey_val.as_py()
         suppkey = suppkey_arr[idx].as_py()
         returnflag = returnflag_arr[idx].as_py()
         value = (partkey, suppkey, returnflag)
@@ -162,10 +162,7 @@ def udf_using_pyarrow_compute_impl(
         resultant_arr = pc.and_(filtered_partkey_arr, filtered_suppkey_arr)
         resultant_arr = pc.and_(resultant_arr, filtered_returnflag_arr)
 
-        if results is None:
-            results = resultant_arr
-        else:
-            results = pc.or_(results, resultant_arr)
+        results = resultant_arr if results is None else pc.or_(results, resultant_arr)
 
     return results
 

examples/python-udf.py (+6 −6)

@@ -15,23 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pyarrow
+import pyarrow as pa
 from datafusion import SessionContext, udf
 from datafusion import functions as f
 
 
-def is_null(array: pyarrow.Array) -> pyarrow.Array:
+def is_null(array: pa.Array) -> pa.Array:
     return array.is_null()
 
 
-is_null_arr = udf(is_null, [pyarrow.int64()], pyarrow.bool_(), "stable")
+is_null_arr = udf(is_null, [pa.int64()], pa.bool_(), "stable")
 
 # create a context
 ctx = SessionContext()
 
 # create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
     names=["a", "b"],
 )
 df = ctx.create_dataframe([[batch]])
@@ -40,4 +40,4 @@ def is_null(array: pyarrow.Array) -> pyarrow.Array:
 
 result = df.collect()[0]
 
-assert result.column(0) == pyarrow.array([False] * 3)
+assert result.column(0) == pa.array([False] * 3)

examples/query-pyarrow-data.py (+5 −5)

@@ -16,15 +16,15 @@
 # under the License.
 
 import datafusion
-import pyarrow
+import pyarrow as pa
 from datafusion import col
 
 # create a context
 ctx = datafusion.SessionContext()
 
 # create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
     names=["a", "b"],
 )
 df = ctx.create_dataframe([[batch]])
@@ -38,5 +38,5 @@
 # execute and collect the first (and only) batch
 result = df.collect()[0]
 
-assert result.column(0) == pyarrow.array([5, 7, 9])
-assert result.column(1) == pyarrow.array([-3, -3, -3])
+assert result.column(0) == pa.array([5, 7, 9])
+assert result.column(1) == pa.array([-3, -3, -3])

examples/sql-using-python-udaf.py (+1 −1)

@@ -25,7 +25,7 @@ class MyAccumulator(Accumulator):
     Interface of a user-defined accumulation.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._sum = pa.scalar(0.0)
 
     def update(self, values: pa.Array) -> None:

examples/tpch/_tests.py (+3 −1)

@@ -91,7 +91,7 @@ def check_q17(df):
         ("q22_global_sales_opportunity", "q22"),
     ],
 )
-def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
+def test_tpch_query_vs_answer_file(query_code: str, answer_file: str) -> None:
     module = import_module(query_code)
     df: DataFrame = module.df
 
@@ -122,3 +122,5 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
 
     assert df.join(df_expected, on=cols, how="anti").count() == 0
     assert df.count() == df_expected.count()
+
+    return None
