diff --git a/pytensor/printing.py b/pytensor/printing.py index bc42029c11..b7b71622e8 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -89,6 +89,7 @@ def debugprint( | Sequence[Variable | Apply | Function | FunctionGraph], depth: int = -1, print_type: bool = False, + print_shape: bool = False, file: Literal["str"] | TextIO | None = None, id_type: IDTypesType = "CHAR", stop_on_name: bool = False, @@ -98,6 +99,7 @@ def debugprint( print_op_info: bool = False, print_destroy_map: bool = False, print_view_map: bool = False, + print_memory_map: bool = False, print_fgraph_inputs: bool = False, ) -> str | TextIO: r"""Print a graph as text. @@ -123,6 +125,8 @@ def debugprint( Print graph to this depth (``-1`` for unlimited). print_type If ``True``, print the `Type`\s of each `Variable` in the graph. + print_shape + If ``True``, print the shape of each `Variable` in the graph. file When `file` extends `TextIO`, print to it; when `file` is equal to ``"str"``, return a string; when `file` is ``None``, print to @@ -153,6 +157,8 @@ def debugprint( Whether to print the `destroy_map`\s of printed objects print_view_map Whether to print the `view_map`\s of printed objects + print_memory_map + Whether to set both `print_destroy_map` and `print_view_map` to ``True``. print_fgraph_inputs Print the inputs of `FunctionGraph`\s. @@ -177,6 +183,10 @@ def debugprint( if used_ids is None: used_ids = dict() + if print_memory_map: + print_destroy_map = True + print_view_map = True + inputs_to_print = [] outputs_to_print = [] profile_list: list[Any | None] = [] @@ -265,6 +275,7 @@ def debugprint( depth=depth, done=done, print_type=print_type, + print_shape=print_shape, file=_file, id_type=id_type, inner_graph_ops=inner_graph_vars, @@ -295,6 +306,7 @@ def debugprint( depth=depth, done=done, print_type=print_type, + print_shape=print_shape, file=_file, topo_order=topo_order, id_type=id_type, @@ -365,6 +377,7 @@ def debugprint( depth=depth, done=done, print_type=print_type, + print_shape=print_shape, file=_file, id_type=id_type, inner_graph_ops=inner_graph_vars, @@ -387,6 +400,7 @@ def debugprint( depth=depth, done=done, print_type=print_type, + print_shape=print_shape, file=_file, id_type=id_type, stop_on_name=stop_on_name, @@ -421,6 +435,7 @@ def debugprint( depth=depth, done=done, print_type=print_type, + print_shape=print_shape, file=_file, id_type=id_type, stop_on_name=stop_on_name, @@ -452,6 +467,7 @@ def _debugprint( depth: int = -1, done: dict[Literal["output"] | Variable | Apply, str] | None = None, print_type: bool = False, + print_shape: bool = False, file: TextIO = sys.stdout, print_destroy_map: bool = False, print_view_map: bool = False, @@ -484,6 +500,8 @@ def _debugprint( See `debugprint`. print_type See `debugprint`. + print_shape + See `debugprint`. file File-like object to which to print. print_destroy_map @@ -532,6 +550,11 @@ def _debugprint( else: type_str = "" + if print_shape and hasattr(var.type, "shape"): + shape_str = f" shape={str(var.type.shape).replace('None', '?')}" + else: + shape_str = "" + if prefix_child is None: prefix_child = prefix @@ -612,7 +635,7 @@ def get_id_str( if is_inner_graph_header: var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}" else: - var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}" + var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{shape_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}" if print_op_info and node not in op_information: op_information.update(op_debug_information(node.op, node)) @@ -662,6 +685,7 @@ def get_id_str( depth=depth - 1, done=_done, print_type=print_type, + print_shape=print_shape, file=file, topo_order=topo_order, id_type=id_type, @@ -692,7 +716,7 @@ def get_id_str( else: data = "" - var_output = f"{prefix}{var}{id_str}{type_str}{data}" + var_output = f"{prefix}{var}{id_str}{type_str}{shape_str}{data}" if print_op_info and var.owner and var.owner not in op_information: op_information.update(op_debug_information(var.owner.op, var.owner)) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index b96113c8e3..937741c4cd 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -399,22 +399,13 @@ def __str__(self): else: shape = self.shape len_shape = len(shape) - - def shape_str(s): - if s is None: - return "?" - else: - return str(s) - - formatted_shape = ", ".join(shape_str(s) for s in shape) - if len_shape == 1: - formatted_shape += "," + formatted_shape = str(shape).replace("None", "?") if len_shape > 2: name = f"Tensor{len_shape}" else: name = ("Scalar", "Vector", "Matrix")[len_shape] - return f"{name}({self.dtype}, shape=({formatted_shape}))" + return f"{name}({self.dtype}, shape={formatted_shape})" def __repr__(self): return f"TensorType({self.dtype}, shape={self.shape})"