Skip to content

Commit 3084f74

Browse files
fix: Reduce GIL hold time for IO plugins in new-streaming (#22186)
1 parent 86036a0 commit 3084f74

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

crates/polars-stream/src/physical_plan/to_graph.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -833,26 +833,9 @@ fn to_graph_rec<'a>(
833833
}?;
834834

835835
let get_batch_fn = Box::new(move |state: &StreamingExecutionState| {
836-
Python::with_gil(|py| {
836+
let df = Python::with_gil(|py| {
837837
match generator.bind(py).call_method0(intern!(py, "__next__")) {
838-
Ok(out) => {
839-
let mut df = polars_plan::plans::python_df_to_rust(py, out)?;
840-
if let (Some(pred), false) =
841-
(&pl_predicate, can_parse_predicate)
842-
{
843-
let mask =
844-
pred.evaluate(&df, &state.in_memory_exec_state)?;
845-
df = df.filter(mask.bool()?)?;
846-
}
847-
if validate_schema {
848-
polars_ensure!(
849-
df.schema() == &output_schema,
850-
SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}",
851-
output_schema, df.schema()
852-
);
853-
}
854-
Ok(Some(df))
855-
},
838+
Ok(out) => polars_plan::plans::python_df_to_rust(py, out).map(Some),
856839
Err(err)
857840
if err.matches(py, PyStopIteration::type_object(py))? =>
858841
{
@@ -862,7 +845,26 @@ fn to_graph_rec<'a>(
862845
ComputeError: "caught exception during execution of a Python source, exception: {err}"
863846
),
864847
}
865-
})
848+
})?;
849+
850+
let Some(mut df) = df else { return Ok(None) };
851+
852+
if validate_schema {
853+
polars_ensure!(
854+
df.schema() == &output_schema,
855+
SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}",
856+
output_schema, df.schema()
857+
);
858+
}
859+
860+
// TODO: Move this to a FilterNode so that it happens in parallel. We may need
861+
// to move all of the enclosing code to `lower_ir` for this.
862+
if let (Some(pred), false) = (&pl_predicate, can_parse_predicate) {
863+
let mask = pred.evaluate(&df, &state.in_memory_exec_state)?;
864+
df = df.filter(mask.bool()?)?;
865+
}
866+
867+
Ok(Some(df))
866868
}) as Box<_>;
867869

868870
(PlSmallStr::from_static("io_plugin"), get_batch_fn)

0 commit comments

Comments
 (0)