From 79b5b6ad565ad1f272cefc7c66d2193a2fd80e8a Mon Sep 17 00:00:00 2001 From: David Horsley Date: Wed, 24 May 2023 09:20:31 +1000 Subject: [PATCH 1/4] Remove unused type: ignores --- pytensor/tensor/rewriting/elemwise.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index afc51a9e3c..6bf4b5b902 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -349,7 +349,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1): inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise) -compile.optdb.register( # type: ignore +compile.optdb.register( "inplace_elemwise_opt", inplace_elemwise_optimizer, "inplace_opt", # for historic reason @@ -1097,7 +1097,7 @@ def print_profile(stream, prof, level=0): "fusion", position=1, ) - compile.optdb.register( # type: ignore + compile.optdb.register( "elemwise_fusion", fuse_seqopt, "fast_run", @@ -1211,7 +1211,7 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] -compile.optdb.register( # type: ignore +compile.optdb.register( "local_careduce_fusion", in2out(local_careduce_fusion), "fusion", @@ -1321,7 +1321,7 @@ def split_2f1grad_loop(fgraph, node): return replacements -compile.optdb["py_only"].register( # type: ignore +compile.optdb["py_only"].register( "split_2f1grad_loop", split_2f1grad_loop, "fast_compile", From 0272bba94cd7d94cfb99e51f13bd2a4b4c93b7a2 Mon Sep 17 00:00:00 2001 From: David Horsley Date: Sat, 20 May 2023 17:52:20 +1000 Subject: [PATCH 2/4] Split blas Ops and rewrites Having Ops and rewrites in the same files was causing circular imports. --- pytensor/tensor/blas.py | 560 +-------------- pytensor/tensor/blas_c.py | 72 -- pytensor/tensor/blas_scipy.py | 44 +- pytensor/tensor/rewriting/__init__.py | 3 + pytensor/tensor/rewriting/blas.py | 907 ++++++++++++++++++++++++ pytensor/tensor/rewriting/blas_c.py | 70 ++ pytensor/tensor/rewriting/blas_scipy.py | 37 + tests/tensor/test_blas.py | 3 +- 8 files changed, 1025 insertions(+), 671 deletions(-) create mode 100644 pytensor/tensor/rewriting/blas.py create mode 100644 pytensor/tensor/rewriting/blas_c.py create mode 100644 pytensor/tensor/rewriting/blas_scipy.py diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 1282cabae5..b276d7339b 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1,4 +1,4 @@ -"""Ops and optimizations for using BLAS calls +"""Ops for using BLAS calls BLAS = Basic Linear Algebra Subroutines Learn more about BLAS here: @@ -71,60 +71,10 @@ that system. -Optimizations -============= - -The optimization pipeline works something like this: - - 1. identify dot22 from dot - 2. identify gemm from dot22 - 3. identify dot22scalar from dot22 that are not gemm - 4. specialize gemm to gemv where applicable - 5. specialize gemm to ger where applicable - 6. specialize dot22 -> gemv or ger where applicable - -:note: GEMM is the most canonical BLAS signature that we deal with so far, it - would be good to turn most things into GEMM (dot, inner, outer, dot22, - dot22scalar), and then to specialize from gemm to the various other L2 and - L3 operations. - -Identify Dot22 --------------- - -Numpy's dot supports arguments that are of any rank, and we should support that -too (just for compatibility). The BLAS optimizations work with Dot Ops whose -inputs are each either vector or matrix. So the first part of the optimization -pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be -transformed further, but they will get implemented by a BLAS call. - -More precisely, Dot nodes whose inputs are all vectors or matrices and whose -inputs both have the same dtype, and whose dtype is float or complex, become -Dot22. This is implemented in `local_dot_to_dot22`. - - -Identify Gemm from Dot22 ------------------------- - -This is complicated, done in GemmOptimizer. - -Identify Dot22Scalar from Dot22 -------------------------------- - -Dot22 Ops that remain after the GemmOptimizer is done have not -qualified as GEMM Ops. Still they might be scaled by a factor, in -which case we use Dot22Scalar which is like Gemm, but without the b -and the Z. In the future it would be good to merge this into the -GemmOptimizer. - -Specialize Gemm to Gemv ------------------------ - -If arguments to GEMM are dimshuffled vectors, then we can use GEMV -instead. This optimization is `local_gemm_to_gemv`. +Optimizations associated with these BLAS Ops are in tensor.rewriting.blas """ -import copy import logging import os import time @@ -140,38 +90,20 @@ from typing import Tuple import pytensor.scalar -from pytensor.compile.mode import optdb from pytensor.configdefaults import config from pytensor.graph.basic import Apply, view_roots -from pytensor.graph.features import ReplacementDidNotRemoveError, ReplaceValidate from pytensor.graph.op import Op -from pytensor.graph.rewriting.basic import ( - EquilibriumGraphRewriter, - GraphRewriter, - copy_stack_trace, - in2out, - node_rewriter, -) -from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.printing import FunctionPrinter, debugprint, pprint +from pytensor.printing import FunctionPrinter, pprint from pytensor.scalar import bool as bool_t from pytensor.tensor import basic as at from pytensor.tensor.blas_headers import blas_header_text, blas_header_version -from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import Dot, add, mul, neg, sub -from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import add, mul, neg, sub from pytensor.tensor.shape import specify_broadcastable -from pytensor.tensor.type import ( - DenseTensorType, - TensorType, - integer_dtypes, - tensor, - values_eq_approx_remove_inf_nan, -) +from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor from pytensor.utils import memoize @@ -1512,150 +1444,6 @@ def _gemm_from_node2(fgraph, node): return None, t1 - t0, 0, 0 -class GemmOptimizer(GraphRewriter): - """Graph optimizer for inserting Gemm operations.""" - - def __init__(self): - super().__init__() - self.warned = False - - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - - def apply(self, fgraph): - did_something = True - nb_iter = 0 - nb_replacement = 0 - nb_replacement_didn_t_remove = 0 - nb_inconsistency_make = 0 - nb_inconsistency_replace = 0 - time_canonicalize = 0 - time_factor_can = 0 - time_factor_list = 0 - time_toposort = 0 - if fgraph.profile: - validate_before = fgraph.profile.validate_time - callbacks_before = fgraph.execute_callbacks_times.copy() - callback_before = fgraph.execute_callbacks_time - - def on_import(new_node): - if new_node is not node: - nodelist.append(new_node) - - u = pytensor.graph.rewriting.basic.DispatchingFeature( - on_import, None, None, name="GemmOptimizer" - ) - fgraph.attach_feature(u) - while did_something: - nb_iter += 1 - t0 = time.perf_counter() - nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs) - time_toposort += time.perf_counter() - t0 - did_something = False - nodelist.reverse() - for node in nodelist: - if not ( - isinstance(node.op, Elemwise) - and isinstance( - node.op.scalar_op, - ( - pytensor.scalar.Add, - pytensor.scalar.Sub, - pytensor.scalar.Neg, - pytensor.scalar.Mul, - ), - ) - ): - continue - if node not in fgraph.apply_nodes: - # This mean that we already removed this node from - # the graph - continue - try: - new_outputs, time1, time2, time3 = _gemm_from_node2(fgraph, node) - time_canonicalize += time1 - time_factor_can += time2 - time_factor_list += time3 - except InconsistencyError: - nb_inconsistency_make += 1 - continue - if new_outputs: - new_outputs, old_dot22 = new_outputs - assert len(new_outputs) == len(node.outputs) - new_outputs[ - 0 - ].tag.values_eq_approx = values_eq_approx_remove_inf_nan - try: - fgraph.replace_all_validate_remove( - list(zip(node.outputs, new_outputs)), - [old_dot22], - reason="GemmOptimizer", - # For now we disable the warning as we know case - # that we need to fix. - warn=False, # warn=not self.warned - ) - did_something = True - nb_replacement += 1 - except InconsistencyError: - # TODO: retry other applications of gemm (see comment - # in _gemm_from_node) - nb_inconsistency_replace += 1 - except ReplacementDidNotRemoveError: - nb_replacement_didn_t_remove += 1 - self.warned = True - fgraph.remove_feature(u) - if fgraph.profile: - validate_time = fgraph.profile.validate_time - validate_before - callback_time = fgraph.execute_callbacks_time - callback_before - callbacks_time = {} - for k, v in fgraph.execute_callbacks_times.items(): - if k in callbacks_before: - callbacks_time[k] = v - callbacks_before[k] - else: - callbacks_time[k] = v - else: - validate_time = None - callback_time = None - callbacks_time = {} - - return ( - self, - nb_iter, - nb_replacement, - nb_replacement_didn_t_remove, - nb_inconsistency_make, - nb_inconsistency_replace, - time_canonicalize, - time_factor_can, - time_factor_list, - time_toposort, - validate_time, - callback_time, - callbacks_time, - ) - - @classmethod - def print_profile(cls, stream, prof, level=0): - blanc = " " * level - print(blanc, cls.__name__, file=stream) - print(blanc, " nb_iter", prof[1], file=stream) - print(blanc, " nb_replacement", prof[2], file=stream) - print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream) - print(blanc, " nb_inconsistency_make", prof[4], file=stream) - print(blanc, " nb_inconsistency_replace", prof[5], file=stream) - print(blanc, " time_canonicalize", prof[6], file=stream) - print(blanc, " time_factor_can", prof[7], file=stream) - print(blanc, " time_factor_list", prof[8], file=stream) - print(blanc, " time_toposort", prof[9], file=stream) - print(blanc, " validate_time", prof[10], file=stream) - print(blanc, " callback_time", prof[11], file=stream) - if prof[11] > 1: - print(blanc, " callbacks_time", file=stream) - for i in sorted(prof[12].items(), key=lambda a: a[1]): - if i[1] > 0: - print(i) - - class Dot22(GemmRelated): """Compute a matrix-matrix product. @@ -1750,207 +1538,6 @@ def c_code_cache_version(self): _dot22 = Dot22() -@node_rewriter([Dot]) -def local_dot_to_dot22(fgraph, node): - # This works for tensor.outer too because basic.outer is a macro that - # produces a dot(dimshuffle,dimshuffle) of form 4 below - if not isinstance(node.op, Dot): - return - - if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): - return False - - x, y = node.inputs - if y.type.dtype != x.type.dtype: - # TODO: upcast one so the types match - _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") - return - - if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): - if x.ndim == 2 and y.ndim == 2: - new_out = [_dot22(*node.inputs)] - elif x.ndim == 2 and y.ndim == 1: - new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] - elif x.ndim == 1 and y.ndim == 2: - new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] - elif x.ndim == 1 and y.ndim == 1: - new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out - - _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") - - -@node_rewriter([gemm_no_inplace], inplace=True) -def local_inplace_gemm(fgraph, node): - if node.op == gemm_no_inplace: - new_out = [gemm_inplace(*node.inputs)] - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([gemv_no_inplace], inplace=True) -def local_inplace_gemv(fgraph, node): - if node.op == gemv_no_inplace: - new_out = [gemv_inplace(*node.inputs)] - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([ger], inplace=True) -def local_inplace_ger(fgraph, node): - if node.op == ger: - new_out = [ger_destructive(*node.inputs)] - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([gemm_no_inplace]) -def local_gemm_to_gemv(fgraph, node): - """GEMM acting on row or column matrices -> GEMV.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if z.broadcastable == x.broadcastable == (True, False): - r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) - new_out = [r.dimshuffle("x", 0)] - elif z.broadcastable == y.broadcastable == (False, True): - r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) - new_out = [r.dimshuffle(0, "x")] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([gemm_no_inplace]) -def local_gemm_to_ger(fgraph, node): - """GEMM computing an outer-product -> GER.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if x.broadcastable[1] and y.broadcastable[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - try: - bval = at.get_underlying_scalar_constant_value(b) - except NotScalarConstantError: - # b isn't a constant, GEMM is doing useful pre-scaling - return - - if bval == 1: # best case a natural GER - rval = ger(z, a, xv, yv) - new_out = [rval] - elif bval == 0: # GER on zeros_like should be faster than GEMM - zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype) - rval = ger(zeros, a, xv, yv) - new_out = [rval] - else: - # if bval is another constant, then z is being usefully - # pre-scaled and GER isn't really the right tool for the job. - return - copy_stack_trace(node.outputs, new_out) - return new_out - - -# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline -# working -@node_rewriter([_dot22]) -def local_dot22_to_ger_or_gemv(fgraph, node): - """dot22 computing an outer-product -> GER.""" - if node.op == _dot22: - x, y = node.inputs - xb = x.broadcastable - yb = y.broadcastable - one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype)) - zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype)) - if xb[1] and yb[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) - rval = ger(zeros, one, xv, yv) - new_out = [rval] - elif xb[0] and yb[1]: - # x and y are both vectors so this qualifies for a sdot / ddot - # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22 - xv = x.dimshuffle(1) - zeros = at.AllocEmpty(x.dtype)(1) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif xb[0] and not yb[0] and not yb[1]: - # x is vector, y is matrix so try gemv - xv = x.dimshuffle(1) - zeros = at.AllocEmpty(x.dtype)(y.shape[1]) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif not xb[0] and not xb[1] and yb[1]: - # x is matrix, y is vector, try gemv - yv = y.dimshuffle(0) - zeros = at.AllocEmpty(x.dtype)(x.shape[0]) - rval = gemv_no_inplace(zeros, one, x, yv, zero) - new_out = [rval.dimshuffle(0, "x")] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out - - -################################# -# -# Set up the BlasOpt optimizer -# -################################# - -blas_optdb = SequenceDB() - -# run after numerical stability optimizations (1.5) -optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7) -# run before specialize (2.0) because specialize is basically a -# free-for-all that makes the graph crazy. - -# fast_compile is needed to have GpuDot22 created. -blas_optdb.register( - "local_dot_to_dot22", - in2out(local_dot_to_dot22), - "fast_run", - "fast_compile", - position=0, -) -blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10) -blas_optdb.register( - "local_gemm_to_gemv", - EquilibriumGraphRewriter( - [ - local_gemm_to_gemv, - local_gemm_to_ger, - local_dot22_to_ger_or_gemv, - local_dimshuffle_lift, - ], - max_use_ratio=5, - ignore_newtrees=False, - ), - "fast_run", - position=15, -) - - -# After destroyhandler(49.5) but before we try to make elemwise things -# inplace (75) -blas_opt_inplace = in2out( - local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace" -) -optdb.register( - "InplaceBlasOpt", - blas_opt_inplace, - "fast_run", - "inplace", - "blas_opt_inplace", - position=70.0, -) - - class Dot22Scalar(GemmRelated): """Compute a matrix-matrix product. @@ -2049,133 +1636,6 @@ def c_code_cache_version(self): _dot22scalar = Dot22Scalar() -@node_rewriter([mul]) -def local_dot22_to_dot22scalar(fgraph, node): - """ - Notes - ----- - Previous attempts to alter this optimization to replace dot22 with - gemm instead of dot22scalar resulted in some Scan nodes being - duplicated and the ScanSaveMem optimization never running on them, - resulting in highly increased memory usage. Until this issue is - resolved, this optimization should keep using dot22scalar instead of - gemm. - - We upcast the scalar if after the multiplication with the dot this give - the same type. - - We execute this optimizer after the gemm optimizer. This - allow to give more priority to gemm that give more speed up - then this optimizer, but allow the gemm optimizer to ignore - this op. - - TODO: support when we can reorder the mul to generate a - dot22scalar or fix the canonizer to merge them(1 mul with multiple - inputs) - - """ - if node.op != mul: - return False - i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs] - if not any(i_dot22): - return False # no dot22 - if i_dot22.count(True) > 1: - # TODO: try each of them. - pass - # return False #TODO fix - dot22_idx = i_dot22.index(True) - d = node.inputs[dot22_idx] - i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] - if not any(i_scalar): - # Check if we can reorder the graph as this mul have a mul in inputs. - # We support only 1 additional level of mul. - # The canonizer should have merged those mul together. - i_mul = [ - x.owner - and x.owner.op == mul - and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs) - for x in node.inputs - ] - if not any(i_mul): - # no scalar in input and no multiplication - # if their was a multiplication we couls reorder the graph - # by the associativity of the graph. - return False - - mul_idx = i_mul.index(True) # The first one should always work - m = node.inputs[mul_idx] - - scalar_idx = -1 - for i, x in enumerate(m.owner.inputs): - if _as_scalar(x, dtype=d.dtype) and ( - pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype - ): - scalar_idx = i - break - - if scalar_idx < 0: - _logger.info( - f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the" - " type of the scalar cannot be upcasted to the" - " matrix type" - ) - return False - a = at.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype) - assert not a.type.ndim - dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) - - # The other inputs to the original node that were - # neither part of the dot22 or this mul should be - # factors in the returned "mul" node. - assert dot22_idx != mul_idx - other_factors = [ - inpt for i, inpt in enumerate(node.inputs) if i not in (dot22_idx, mul_idx) - ] - other_m_inputs = [ - inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx - ] - - return [mul(dot, *(other_factors + other_m_inputs))] - - scalar_idx = -1 - for i, x in enumerate(node.inputs): - if ( - i != dot22_idx - and i_scalar[i] is not None - and (pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype) - ): - scalar_idx = i - break - if scalar_idx < 0: - _logger.info( - f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the type " - "of the scalar cannot be upcasted to the matrix type" - ) - return False - assert scalar_idx < len(node.inputs) - s = node.inputs[scalar_idx] - o = copy.copy(node.inputs) - o.remove(d) - o.remove(s) - - a = at.cast(i_scalar[scalar_idx], d.type.dtype) - assert not a.type.ndim - if len(o) == 0: - return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)] - else: - return [mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)] - - -# must happen after gemm as the gemm optimizer don't understant -# dot22scalar and gemm give more speed up then dot22scalar -blas_optdb.register( - "local_dot22_to_dot22scalar", - in2out(local_dot22_to_dot22scalar), - "fast_run", - position=11, -) - - class BatchedDot(COp): """ Computes the batched dot product of two variables: @@ -2669,14 +2129,6 @@ def infer_shape(self, fgraph, node, shapes): _batched_dot = BatchedDot() -# from opt import register_specialize, register_canonicalize -# @register_specialize -@node_rewriter([sub, add]) -def local_print_as_we_go_along(fgraph, node): - if node.op in (sub, add): - debugprint(node) - - def batched_dot(a, b): """Compute the batched dot product of two variables. diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index e4e90066b0..704970b5ef 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -1,22 +1,12 @@ -from pytensor.configdefaults import config -from pytensor.graph.rewriting.basic import in2out from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.scalar import bool as bool_t -from pytensor.tensor import basic as at from pytensor.tensor.blas import ( Gemv, Ger, blas_header_text, blas_header_version, - blas_optdb, - gemv_inplace, - gemv_no_inplace, - ger, - ger_destructive, ldflags, - node_rewriter, - optdb, ) @@ -344,23 +334,6 @@ def c_code_cache_version(self): cger_no_inplace = CGer(False) -@node_rewriter([ger, ger_destructive]) -def use_c_ger(fgraph, node): - if not config.blas__ldflags: - return - # Only float32 and float64 are supported for now. - if node.op == ger and node.outputs[0].dtype in ("float32", "float64"): - return [CGer(False)(*node.inputs)] - if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"): - return [CGer(True)(*node.inputs)] - - -@node_rewriter([CGer(False)]) -def make_c_ger_destructive(fgraph, node): - if isinstance(node.op, CGer) and not node.op.destructive: - return [cger_inplace(*node.inputs)] - - # ##### ####### ####### # GEMV # ##### ####### ####### @@ -697,48 +670,3 @@ def check_force_gemv_init(): check_force_gemv_init._force_init_beta = None - - -@node_rewriter([gemv_inplace, gemv_no_inplace]) -def use_c_gemv(fgraph, node): - if not config.blas__ldflags: - return - # Only float32 and float64 are supported for now. - if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"): - return [cgemv_no_inplace(*node.inputs)] - if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"): - return [cgemv_inplace(*node.inputs)] - - -@node_rewriter([CGemv(inplace=False)]) -def make_c_gemv_destructive(fgraph, node): - if isinstance(node.op, CGemv) and not node.op.inplace: - inputs = list(node.inputs) - dest = inputs[0] - if ( - dest.owner - and isinstance(dest.owner.op, at.AllocEmpty) - and len(fgraph.clients[dest]) > 1 - ): - inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs) - - return [cgemv_inplace(*inputs)] - - -# ##### ####### ####### -# Optimizers -# ##### ####### ####### - -blas_optdb.register( - "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 -) - -# this matches the InplaceBlasOpt defined in blas.py -optdb.register( - "c_blas_destructive", - in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), - "fast_run", - "inplace", - "c_blas", - position=70.0, -) diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py index 4d1be6e322..527d5150a1 100644 --- a/pytensor/tensor/blas_scipy.py +++ b/pytensor/tensor/blas_scipy.py @@ -4,16 +4,7 @@ import numpy as np -from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.blas import ( - Ger, - blas_optdb, - ger, - ger_destructive, - have_fblas, - node_rewriter, - optdb, -) +from pytensor.tensor.blas import Ger, have_fblas if have_fblas: @@ -56,36 +47,3 @@ def perform(self, node, inputs, output_storage): scipy_ger_no_inplace = ScipyGer(False) scipy_ger_inplace = ScipyGer(True) - - -@node_rewriter([ger, ger_destructive]) -def use_scipy_ger(fgraph, node): - if node.op == ger: - return [scipy_ger_no_inplace(*node.inputs)] - - -@node_rewriter([scipy_ger_no_inplace]) -def make_ger_destructive(fgraph, node): - if node.op == scipy_ger_no_inplace: - return [scipy_ger_inplace(*node.inputs)] - - -use_scipy_blas = in2out(use_scipy_ger) -make_scipy_blas_destructive = in2out(make_ger_destructive) - -if have_fblas: - # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof - # sucks, but it is almost always present. - # C implementations should be scheduled earlier than this, so that they take - # precedence. Once the original Ger is replaced, then these optimizations - # have no effect. - blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) - - # this matches the InplaceBlasOpt defined in blas.py - optdb.register( - "make_scipy_blas_destructive", - make_scipy_blas_destructive, - "fast_run", - "inplace", - position=70.0, - ) diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index cb244afb7e..d8836e4b7b 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -1,4 +1,7 @@ import pytensor.tensor.rewriting.basic +import pytensor.tensor.rewriting.blas +import pytensor.tensor.rewriting.blas_c +import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py new file mode 100644 index 0000000000..a310cb5837 --- /dev/null +++ b/pytensor/tensor/rewriting/blas.py @@ -0,0 +1,907 @@ +"""optimizations for using BLAS calls + +Optimizations +============= + +The optimization pipeline works something like this: + + 1. identify dot22 from dot + 2. identify gemm from dot22 + 3. identify dot22scalar from dot22 that are not gemm + 4. specialize gemm to gemv where applicable + 5. specialize gemm to ger where applicable + 6. specialize dot22 -> gemv or ger where applicable + +:note: GEMM is the most canonical BLAS signature that we deal with so far, it + would be good to turn most things into GEMM (dot, inner, outer, dot22, + dot22scalar), and then to specialize from gemm to the various other L2 and + L3 operations. + +Identify Dot22 +-------------- + +Numpy's dot supports arguments that are of any rank, and we should support that +too (just for compatibility). The BLAS optimizations work with Dot Ops whose +inputs are each either vector or matrix. So the first part of the optimization +pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be +transformed further, but they will get implemented by a BLAS call. + +More precisely, Dot nodes whose inputs are all vectors or matrices and whose +inputs both have the same dtype, and whose dtype is float or complex, become +Dot22. This is implemented in `local_dot_to_dot22`. + + +Identify Gemm from Dot22 +------------------------ + +This is complicated, done in GemmOptimizer. + +Identify Dot22Scalar from Dot22 +------------------------------- + +Dot22 Ops that remain after the GemmOptimizer is done have not +qualified as GEMM Ops. Still they might be scaled by a factor, in +which case we use Dot22Scalar which is like Gemm, but without the b +and the Z. In the future it would be good to merge this into the +GemmOptimizer. + +Specialize Gemm to Gemv +----------------------- + +If arguments to GEMM are dimshuffled vectors, then we can use GEMV +instead. This optimization is `local_gemm_to_gemv`. + +""" + +import copy +import logging +import time + +import numpy as np + + +try: + import numpy.__config__ # noqa +except ImportError: + pass + + +import pytensor.scalar +from pytensor.compile.mode import optdb +from pytensor.configdefaults import config +from pytensor.graph.features import ReplacementDidNotRemoveError, ReplaceValidate +from pytensor.graph.rewriting.basic import ( + EquilibriumGraphRewriter, + GraphRewriter, + copy_stack_trace, + in2out, + node_rewriter, +) +from pytensor.graph.rewriting.db import SequenceDB +from pytensor.graph.utils import InconsistencyError +from pytensor.printing import debugprint +from pytensor.tensor import basic as at +from pytensor.tensor.blas import ( + Dot22, + _dot22, + _dot22scalar, + gemm_inplace, + gemm_no_inplace, + gemv_inplace, + gemv_no_inplace, + ger, + ger_destructive, +) +from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.math import Dot, add, mul, neg, sub +from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift +from pytensor.tensor.type import ( + DenseTensorType, + TensorType, + integer_dtypes, + values_eq_approx_remove_inf_nan, +) + + +_logger = logging.getLogger("pytensor.tensor.rewriting.blas") + + +def res_is_a(fgraph, var, op, maxclients=None): + if maxclients is not None and var in fgraph.clients: + retval = len(fgraph.get_clients(var)) <= maxclients + else: + retval = True + + return var.owner and var.owner.op == op and retval + + +def _as_scalar(res, dtype=None): + """Return ``None`` or a `TensorVariable` of float type""" + if dtype is None: + dtype = config.floatX + if all(s == 1 for s in res.type.shape): + while res.owner and isinstance(res.owner.op, DimShuffle): + res = res.owner.inputs[0] + # may still have some number of True's + if res.type.ndim > 0: + rval = res.dimshuffle() + else: + rval = res + if rval.type.dtype in integer_dtypes: + # We check that the upcast of res and dtype won't change dtype. + # If dtype is float64, we will cast int64 to float64. + # This is valid when res is a scalar used as input to a dot22 + # as the cast of the scalar can be done before or after the dot22 + # and this will give the same result. + if pytensor.scalar.upcast(res.dtype, dtype) == dtype: + return at.cast(rval, dtype) + else: + return None + + return rval + + +def _is_real_matrix(res): + return ( + res.type.dtype in ("float16", "float32", "float64") + and res.type.ndim == 2 + and res.type.shape[0] != 1 + and res.type.shape[1] != 1 + ) # cope with tuple vs. list + + +def _is_real_vector(res): + return ( + res.type.dtype in ("float16", "float32", "float64") + and res.type.ndim == 1 + and res.type.shape[0] != 1 + ) + + +def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True): + # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip + # EXPRESSION: (beta * L) + (alpha * M) + + # we've already checked the client counts, now just make the type check. + # if res_is_a(M, _dot22, 1): + if M.owner and M.owner.op == _dot22: + Ml, Mr = M.owner.inputs + rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] + return rval, M + + # it also might be the case that there is a dimshuffle between the + + # and the dot22. local_dot_to_dot22 in particular will put in such things. + if ( + M.owner + and isinstance(M.owner.op, DimShuffle) + and M.owner.inputs[0].owner + and isinstance(M.owner.inputs[0].owner.op, Dot22) + ): + MM = M.owner.inputs[0] + if M.owner.op.new_order == (0,): + # it is making a column MM into a vector + MMl, MMr = MM.owner.inputs + g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta) + rval = [g.dimshuffle(0)] + return rval, MM + if M.owner.op.new_order == (1,): + # it is making a row MM into a vector + MMl, MMr = MM.owner.inputs + g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta) + rval = [g.dimshuffle(1)] + return rval, MM + if len(M.owner.op.new_order) == 0: + # it is making a row MM into a vector + MMl, MMr = MM.owner.inputs + g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta) + rval = [g.dimshuffle()] + return rval, MM + + if recurse_flip: + return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False) + else: + return False, False + + +def _gemm_canonicalize(fgraph, r, scale, rval, maxclients): + # Tries to interpret node as a sum of scalars * (vectors or matrices) + def scaled(thing): + if scale == 1: + return thing + if scale == -1 and thing.type.dtype != "bool": + return -thing + else: + return scale * thing + + if not isinstance(r.type, TensorType): + return None + + if (r.type.ndim not in (1, 2)) or r.type.dtype not in ( + "float16", + "float32", + "float64", + "complex64", + "complex128", + ): + rval.append(scaled(r)) + return rval + + if maxclients and len(fgraph.clients[r]) > maxclients: + rval.append((scale, r)) + return rval + + if r.owner and r.owner.op == sub: + _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1) + _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1) + + elif r.owner and r.owner.op == add: + for i in r.owner.inputs: + _gemm_canonicalize(fgraph, i, scale, rval, 1) + + elif r.owner and r.owner.op == neg: + _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1) + + elif r.owner and r.owner.op == mul: + scalars = [] + vectors = [] + matrices = [] + for i in r.owner.inputs: + if all(s == 1 for s in i.type.shape): + while i.owner and isinstance(i.owner.op, DimShuffle): + i = i.owner.inputs[0] + if i.type.ndim > 0: + scalars.append(i.dimshuffle()) + else: + scalars.append(i) + elif _is_real_vector(i): + vectors.append(i) + elif _is_real_matrix(i): + matrices.append(i) + else: + # just put the original arguments as in the base case + rval.append((scale, r)) + return rval + if len(matrices) == 1: + assert len(vectors) == 0 + m = matrices[0] + if len(scalars) == 0: + _gemm_canonicalize(fgraph, m, scale, rval, 1) + elif len(scalars) == 1: + _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1) + else: + _gemm_canonicalize( + fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1 + ) + elif len(vectors) == 1: + assert len(matrices) == 0 + v = vectors[0] + if len(scalars) == 0: + _gemm_canonicalize(fgraph, v, scale, rval, 1) + elif len(scalars) == 1: + _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1) + else: + _gemm_canonicalize( + fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1 + ) + else: # lets not open this up + rval.append((scale, r)) + else: + rval.append((scale, r)) + return rval + + +def _factor_canonicalized(lst): + # remove duplicates from canonicalized list + + # we only delete out of the right end of the list, + # once i has touched a list element, it is permantent + lst = list(lst) + # print 'FACTOR', lst + # for t in lst: + # if not isinstance(t, (list, tuple)): + # t = (t,) + # for e in t: + # try: + # pytensor.printing.debugprint(e) + # except TypeError: + # print e, type(e) + i = 0 + while i < len(lst) - 1: + try: + s_i, M_i = lst[i] + except Exception: + i += 1 + continue + + j = i + 1 + while j < len(lst): + try: + s_j, M_j = lst[j] + except Exception: + j += 1 + continue + + if M_i is M_j: + s_i = s_i + s_j + lst[i] = (s_i, M_i) + del lst[j] + else: + j += 1 + i += 1 + return lst + + +def _gemm_from_factored_list(fgraph, lst): + """ + Returns None, or a list to replace node.outputs. + + """ + lst2 = [] + # Remove the tuple that can't be cast correctly. + # This can happen when we try to cast a complex to a real + for sM in lst: + # Make every pair in list have matching dtypes + # sM can be a tuple of 2 elements or an PyTensor variable. + if isinstance(sM, tuple): + sm0, sm1 = sM + sm0 = at.as_tensor_variable(sm0) + if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype: + lst2.append((at.cast(sm0, sm1.dtype), sM[1])) + + lst = lst2 + + def item_to_var(t): + try: + s, M = t + except Exception: + return t + if s == 1: + return M + if s == -1: + return -M + return s * M + + # Try every pair in the sM_list, trying to turn it into a gemm operation + for i in range(len(lst) - 1): + s_i, M_i = lst[i] + + for j in range(i + 1, len(lst)): + s_j, M_j = lst[j] + + if not M_j.type.in_same_class(M_i.type): + continue + + # print 'TRYING', (s_i, M_i, s_j, M_j) + + gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M( + fgraph, s_i, M_i, s_j, M_j + ) + # print 'GOT IT', gemm_of_sM_list + if gemm_of_sM_list: + assert len(gemm_of_sM_list) == 1 + add_inputs = [ + item_to_var(input) for k, input in enumerate(lst) if k not in (i, j) + ] + add_inputs.extend(gemm_of_sM_list) + if len(add_inputs) > 1: + rval = [add(*add_inputs)] + else: + rval = add_inputs + # print "RETURNING GEMM THING", rval + return rval, old_dot22 + + +def _gemm_from_node2(fgraph, node): + """ + + TODO: In many expressions, there are many ways to turn it into a + gemm. For example dot(a,b) + c + d. This function should return all + of them, so that if one version of gemm causes a cycle in the graph, then + another application of gemm can be tried. + + """ + lst = [] + t0 = time.perf_counter() + _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0) + t1 = time.perf_counter() + + if len(lst) > 1: + lst = _factor_canonicalized(lst) + t2 = time.perf_counter() + rval = _gemm_from_factored_list(fgraph, lst) + t3 = time.perf_counter() + + # It can happen that _factor_canonicalized and + # _gemm_from_factored_list return a node with an incorrect + # type. This happens in particular when one of the scalar + # factors forces the upcast of the whole expression. In that + # case, we simply skip that candidate for Gemm. This was + # discussed in + # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, + # but never made it into a trac ticket. + + if rval and rval[0][0].type.in_same_class(node.outputs[0].type): + return rval, t1 - t0, t2 - t1, t3 - t2 + + return None, t1 - t0, 0, 0 + + +class GemmOptimizer(GraphRewriter): + """Graph optimizer for inserting Gemm operations.""" + + def __init__(self): + super().__init__() + self.warned = False + + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + + def apply(self, fgraph): + did_something = True + nb_iter = 0 + nb_replacement = 0 + nb_replacement_didn_t_remove = 0 + nb_inconsistency_make = 0 + nb_inconsistency_replace = 0 + time_canonicalize = 0 + time_factor_can = 0 + time_factor_list = 0 + time_toposort = 0 + if fgraph.profile: + validate_before = fgraph.profile.validate_time + callbacks_before = fgraph.execute_callbacks_times.copy() + callback_before = fgraph.execute_callbacks_time + + def on_import(new_node): + if new_node is not node: + nodelist.append(new_node) + + u = pytensor.graph.rewriting.basic.DispatchingFeature( + on_import, None, None, name="GemmOptimizer" + ) + fgraph.attach_feature(u) + while did_something: + nb_iter += 1 + t0 = time.perf_counter() + nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs) + time_toposort += time.perf_counter() - t0 + did_something = False + nodelist.reverse() + for node in nodelist: + if not ( + isinstance(node.op, Elemwise) + and isinstance( + node.op.scalar_op, + ( + pytensor.scalar.Add, + pytensor.scalar.Sub, + pytensor.scalar.Neg, + pytensor.scalar.Mul, + ), + ) + ): + continue + if node not in fgraph.apply_nodes: + # This mean that we already removed this node from + # the graph + continue + try: + new_outputs, time1, time2, time3 = _gemm_from_node2(fgraph, node) + time_canonicalize += time1 + time_factor_can += time2 + time_factor_list += time3 + except InconsistencyError: + nb_inconsistency_make += 1 + continue + if new_outputs: + new_outputs, old_dot22 = new_outputs + assert len(new_outputs) == len(node.outputs) + new_outputs[ + 0 + ].tag.values_eq_approx = values_eq_approx_remove_inf_nan + try: + fgraph.replace_all_validate_remove( + list(zip(node.outputs, new_outputs)), + [old_dot22], + reason="GemmOptimizer", + # For now we disable the warning as we know case + # that we need to fix. + warn=False, # warn=not self.warned + ) + did_something = True + nb_replacement += 1 + except InconsistencyError: + # TODO: retry other applications of gemm (see comment + # in _gemm_from_node) + nb_inconsistency_replace += 1 + except ReplacementDidNotRemoveError: + nb_replacement_didn_t_remove += 1 + self.warned = True + fgraph.remove_feature(u) + if fgraph.profile: + validate_time = fgraph.profile.validate_time - validate_before + callback_time = fgraph.execute_callbacks_time - callback_before + callbacks_time = {} + for k, v in fgraph.execute_callbacks_times.items(): + if k in callbacks_before: + callbacks_time[k] = v - callbacks_before[k] + else: + callbacks_time[k] = v + else: + validate_time = None + callback_time = None + callbacks_time = {} + + return ( + self, + nb_iter, + nb_replacement, + nb_replacement_didn_t_remove, + nb_inconsistency_make, + nb_inconsistency_replace, + time_canonicalize, + time_factor_can, + time_factor_list, + time_toposort, + validate_time, + callback_time, + callbacks_time, + ) + + @classmethod + def print_profile(cls, stream, prof, level=0): + blanc = " " * level + print(blanc, cls.__name__, file=stream) + print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_replacement", prof[2], file=stream) + print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream) + print(blanc, " nb_inconsistency_make", prof[4], file=stream) + print(blanc, " nb_inconsistency_replace", prof[5], file=stream) + print(blanc, " time_canonicalize", prof[6], file=stream) + print(blanc, " time_factor_can", prof[7], file=stream) + print(blanc, " time_factor_list", prof[8], file=stream) + print(blanc, " time_toposort", prof[9], file=stream) + print(blanc, " validate_time", prof[10], file=stream) + print(blanc, " callback_time", prof[11], file=stream) + if prof[11] > 1: + print(blanc, " callbacks_time", file=stream) + for i in sorted(prof[12].items(), key=lambda a: a[1]): + if i[1] > 0: + print(i) + + +@node_rewriter([Dot]) +def local_dot_to_dot22(fgraph, node): + # This works for tensor.outer too because basic.outer is a macro that + # produces a dot(dimshuffle,dimshuffle) of form 4 below + if not isinstance(node.op, Dot): + return + + if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): + return False + + x, y = node.inputs + if y.type.dtype != x.type.dtype: + # TODO: upcast one so the types match + _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") + return + + if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): + if x.ndim == 2 and y.ndim == 2: + new_out = [_dot22(*node.inputs)] + elif x.ndim == 2 and y.ndim == 1: + new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] + elif x.ndim == 1 and y.ndim == 2: + new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] + elif x.ndim == 1 and y.ndim == 1: + new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out + + _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") + + +@node_rewriter([gemm_no_inplace], inplace=True) +def local_inplace_gemm(fgraph, node): + if node.op == gemm_no_inplace: + new_out = [gemm_inplace(*node.inputs)] + copy_stack_trace(node.outputs, new_out) + return new_out + + +@node_rewriter([gemv_no_inplace], inplace=True) +def local_inplace_gemv(fgraph, node): + if node.op == gemv_no_inplace: + new_out = [gemv_inplace(*node.inputs)] + copy_stack_trace(node.outputs, new_out) + return new_out + + +@node_rewriter([ger], inplace=True) +def local_inplace_ger(fgraph, node): + if node.op == ger: + new_out = [ger_destructive(*node.inputs)] + copy_stack_trace(node.outputs, new_out) + return new_out + + +@node_rewriter([gemm_no_inplace]) +def local_gemm_to_gemv(fgraph, node): + """GEMM acting on row or column matrices -> GEMV.""" + if node.op == gemm_no_inplace: + z, a, x, y, b = node.inputs + if z.broadcastable == x.broadcastable == (True, False): + r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) + new_out = [r.dimshuffle("x", 0)] + elif z.broadcastable == y.broadcastable == (False, True): + r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) + new_out = [r.dimshuffle(0, "x")] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out + + +@node_rewriter([gemm_no_inplace]) +def local_gemm_to_ger(fgraph, node): + """GEMM computing an outer-product -> GER.""" + if node.op == gemm_no_inplace: + z, a, x, y, b = node.inputs + if x.broadcastable[1] and y.broadcastable[0]: + # x and y are both vectors so this might qualifies for a GER + xv = x.dimshuffle(0) + yv = y.dimshuffle(1) + try: + bval = at.get_underlying_scalar_constant_value(b) + except NotScalarConstantError: + # b isn't a constant, GEMM is doing useful pre-scaling + return + + if bval == 1: # best case a natural GER + rval = ger(z, a, xv, yv) + new_out = [rval] + elif bval == 0: # GER on zeros_like should be faster than GEMM + zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype) + rval = ger(zeros, a, xv, yv) + new_out = [rval] + else: + # if bval is another constant, then z is being usefully + # pre-scaled and GER isn't really the right tool for the job. + return + copy_stack_trace(node.outputs, new_out) + return new_out + + +# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline +# working +@node_rewriter([_dot22]) +def local_dot22_to_ger_or_gemv(fgraph, node): + """dot22 computing an outer-product -> GER.""" + if node.op == _dot22: + x, y = node.inputs + xb = x.broadcastable + yb = y.broadcastable + one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype)) + zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype)) + if xb[1] and yb[0]: + # x and y are both vectors so this might qualifies for a GER + xv = x.dimshuffle(0) + yv = y.dimshuffle(1) + zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) + rval = ger(zeros, one, xv, yv) + new_out = [rval] + elif xb[0] and yb[1]: + # x and y are both vectors so this qualifies for a sdot / ddot + # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22 + xv = x.dimshuffle(1) + zeros = at.AllocEmpty(x.dtype)(1) + rval = gemv_no_inplace(zeros, one, y.T, xv, zero) + new_out = [rval.dimshuffle("x", 0)] + elif xb[0] and not yb[0] and not yb[1]: + # x is vector, y is matrix so try gemv + xv = x.dimshuffle(1) + zeros = at.AllocEmpty(x.dtype)(y.shape[1]) + rval = gemv_no_inplace(zeros, one, y.T, xv, zero) + new_out = [rval.dimshuffle("x", 0)] + elif not xb[0] and not xb[1] and yb[1]: + # x is matrix, y is vector, try gemv + yv = y.dimshuffle(0) + zeros = at.AllocEmpty(x.dtype)(x.shape[0]) + rval = gemv_no_inplace(zeros, one, x, yv, zero) + new_out = [rval.dimshuffle(0, "x")] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out + + +################################# +# +# Set up the BlasOpt optimizer +# +################################# + +blas_optdb = SequenceDB() + +# run after numerical stability optimizations (1.5) +optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7) +# run before specialize (2.0) because specialize is basically a +# free-for-all that makes the graph crazy. + +# fast_compile is needed to have GpuDot22 created. +blas_optdb.register( + "local_dot_to_dot22", + in2out(local_dot_to_dot22), + "fast_run", + "fast_compile", + position=0, +) +blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10) +blas_optdb.register( + "local_gemm_to_gemv", + EquilibriumGraphRewriter( + [ + local_gemm_to_gemv, + local_gemm_to_ger, + local_dot22_to_ger_or_gemv, + local_dimshuffle_lift, + ], + max_use_ratio=5, + ignore_newtrees=False, + ), + "fast_run", + position=15, +) + + +# After destroyhandler(49.5) but before we try to make elemwise things +# inplace (75) +blas_opt_inplace = in2out( + local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace" +) +optdb.register( + "InplaceBlasOpt", + blas_opt_inplace, + "fast_run", + "inplace", + "blas_opt_inplace", + position=70.0, +) + + +@node_rewriter([mul]) +def local_dot22_to_dot22scalar(fgraph, node): + """ + Notes + ----- + Previous attempts to alter this optimization to replace dot22 with + gemm instead of dot22scalar resulted in some Scan nodes being + duplicated and the ScanSaveMem optimization never running on them, + resulting in highly increased memory usage. Until this issue is + resolved, this optimization should keep using dot22scalar instead of + gemm. + + We upcast the scalar if after the multiplication with the dot this give + the same type. + + We execute this optimizer after the gemm optimizer. This + allow to give more priority to gemm that give more speed up + then this optimizer, but allow the gemm optimizer to ignore + this op. + + TODO: support when we can reorder the mul to generate a + dot22scalar or fix the canonizer to merge them(1 mul with multiple + inputs) + + """ + if node.op != mul: + return False + i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs] + if not any(i_dot22): + return False # no dot22 + if i_dot22.count(True) > 1: + # TODO: try each of them. + pass + # return False #TODO fix + dot22_idx = i_dot22.index(True) + d = node.inputs[dot22_idx] + i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] + if not any(i_scalar): + # Check if we can reorder the graph as this mul have a mul in inputs. + # We support only 1 additional level of mul. + # The canonizer should have merged those mul together. + i_mul = [ + x.owner + and x.owner.op == mul + and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs) + for x in node.inputs + ] + if not any(i_mul): + # no scalar in input and no multiplication + # if their was a multiplication we couls reorder the graph + # by the associativity of the graph. + return False + + mul_idx = i_mul.index(True) # The first one should always work + m = node.inputs[mul_idx] + + scalar_idx = -1 + for i, x in enumerate(m.owner.inputs): + if _as_scalar(x, dtype=d.dtype) and ( + pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype + ): + scalar_idx = i + break + + if scalar_idx < 0: + _logger.info( + f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the" + " type of the scalar cannot be upcasted to the" + " matrix type" + ) + return False + a = at.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype) + assert not a.type.ndim + dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) + + # The other inputs to the original node that were + # neither part of the dot22 or this mul should be + # factors in the returned "mul" node. + assert dot22_idx != mul_idx + other_factors = [ + inpt for i, inpt in enumerate(node.inputs) if i not in (dot22_idx, mul_idx) + ] + other_m_inputs = [ + inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx + ] + + return [mul(dot, *(other_factors + other_m_inputs))] + + scalar_idx = -1 + for i, x in enumerate(node.inputs): + if ( + i != dot22_idx + and i_scalar[i] is not None + and (pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype) + ): + scalar_idx = i + break + if scalar_idx < 0: + _logger.info( + f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the type " + "of the scalar cannot be upcasted to the matrix type" + ) + return False + assert scalar_idx < len(node.inputs) + s = node.inputs[scalar_idx] + o = copy.copy(node.inputs) + o.remove(d) + o.remove(s) + + a = at.cast(i_scalar[scalar_idx], d.type.dtype) + assert not a.type.ndim + if len(o) == 0: + return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)] + else: + return [mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)] + + +# must happen after gemm as the gemm optimizer don't understant +# dot22scalar and gemm give more speed up then dot22scalar +blas_optdb.register( + "local_dot22_to_dot22scalar", + in2out(local_dot22_to_dot22scalar), + "fast_run", + position=11, +) + + +# from opt import register_specialize, register_canonicalize +# @register_specialize +@node_rewriter([sub, add]) +def local_print_as_we_go_along(fgraph, node): + if node.op in (sub, add): + debugprint(node) diff --git a/pytensor/tensor/rewriting/blas_c.py b/pytensor/tensor/rewriting/blas_c.py new file mode 100644 index 0000000000..77629dccca --- /dev/null +++ b/pytensor/tensor/rewriting/blas_c.py @@ -0,0 +1,70 @@ +from pytensor.configdefaults import config +from pytensor.graph.rewriting.basic import in2out +from pytensor.tensor import basic as at +from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive +from pytensor.tensor.blas_c import ( + CGemv, + CGer, + cgemv_inplace, + cgemv_no_inplace, + cger_inplace, +) +from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb + + +@node_rewriter([ger, ger_destructive]) +def use_c_ger(fgraph, node): + if not config.blas__ldflags: + return + # Only float32 and float64 are supported for now. + if node.op == ger and node.outputs[0].dtype in ("float32", "float64"): + return [CGer(False)(*node.inputs)] + if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"): + return [CGer(True)(*node.inputs)] + + +@node_rewriter([CGer(False)]) +def make_c_ger_destructive(fgraph, node): + if isinstance(node.op, CGer) and not node.op.destructive: + return [cger_inplace(*node.inputs)] + + +@node_rewriter([gemv_inplace, gemv_no_inplace]) +def use_c_gemv(fgraph, node): + if not config.blas__ldflags: + return + # Only float32 and float64 are supported for now. + if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"): + return [cgemv_no_inplace(*node.inputs)] + if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"): + return [cgemv_inplace(*node.inputs)] + + +@node_rewriter([CGemv(inplace=False)]) +def make_c_gemv_destructive(fgraph, node): + if isinstance(node.op, CGemv) and not node.op.inplace: + inputs = list(node.inputs) + dest = inputs[0] + if ( + dest.owner + and isinstance(dest.owner.op, at.AllocEmpty) + and len(fgraph.clients[dest]) > 1 + ): + inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs) + + return [cgemv_inplace(*inputs)] + + +blas_optdb.register( + "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 +) + +# this matches the InplaceBlasOpt defined in blas.py +optdb.register( + "c_blas_destructive", + in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), + "fast_run", + "inplace", + "c_blas", + position=70.0, +) diff --git a/pytensor/tensor/rewriting/blas_scipy.py b/pytensor/tensor/rewriting/blas_scipy.py new file mode 100644 index 0000000000..2b2aa94eef --- /dev/null +++ b/pytensor/tensor/rewriting/blas_scipy.py @@ -0,0 +1,37 @@ +from pytensor.graph.rewriting.basic import in2out +from pytensor.tensor.blas import ger, ger_destructive, have_fblas +from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace +from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb + + +@node_rewriter([ger, ger_destructive]) +def use_scipy_ger(fgraph, node): + if node.op == ger: + return [scipy_ger_no_inplace(*node.inputs)] + + +@node_rewriter([scipy_ger_no_inplace]) +def make_ger_destructive(fgraph, node): + if node.op == scipy_ger_no_inplace: + return [scipy_ger_inplace(*node.inputs)] + + +use_scipy_blas = in2out(use_scipy_ger) +make_scipy_blas_destructive = in2out(make_ger_destructive) + +if have_fblas: + # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof + # sucks, but it is almost always present. + # C implementations should be scheduled earlier than this, so that they take + # precedence. Once the original Ger is replaced, then these optimizations + # have no effect. + blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) + + # this matches the InplaceBlasOpt defined in blas.py + optdb.register( + "make_scipy_blas_destructive", + make_scipy_blas_destructive, + "fast_run", + "inplace", + position=70.0, + ) diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 0ce7640d38..035f9e036b 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -44,12 +44,11 @@ gemv_no_inplace, ger, ger_destructive, - local_dot22_to_dot22scalar, - local_gemm_to_ger, res_is_a, ) from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt +from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger from pytensor.tensor.type import ( cmatrix, col, From 39051061f6760fd94abe9f34b070f40ad84acba8 Mon Sep 17 00:00:00 2001 From: David Horsley Date: Fri, 19 May 2023 18:58:36 +1000 Subject: [PATCH 3/4] Move linalg rewrites, delete sandbox Moving caused a circular dependency with tensor.blas. It seems most linalg rewrites are in the stablize set, so should run before the blas specializers anyway, so these checks were removed. This also deleted the unused `spectral_radius_bound` and dummy `Minimal(Op)`. --- pytensor/sandbox/__init__.py | 0 pytensor/sandbox/linalg/__init__.py | 1 - pytensor/sandbox/minimal.py | 46 ----------------- pytensor/tensor/rewriting/__init__.py | 1 + .../ops.py => tensor/rewriting/linalg.py} | 0 tests/sandbox/__init__.py | 0 tests/sandbox/linalg/__init__.py | 0 tests/sandbox/test_minimal.py | 32 ------------ .../rewriting}/test_linalg.py | 49 +------------------ 9 files changed, 2 insertions(+), 127 deletions(-) delete mode 100644 pytensor/sandbox/__init__.py delete mode 100644 pytensor/sandbox/linalg/__init__.py delete mode 100644 pytensor/sandbox/minimal.py rename pytensor/{sandbox/linalg/ops.py => tensor/rewriting/linalg.py} (100%) delete mode 100644 tests/sandbox/__init__.py delete mode 100644 tests/sandbox/linalg/__init__.py delete mode 100644 tests/sandbox/test_minimal.py rename tests/{sandbox/linalg => tensor/rewriting}/test_linalg.py (74%) diff --git a/pytensor/sandbox/__init__.py b/pytensor/sandbox/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pytensor/sandbox/linalg/__init__.py b/pytensor/sandbox/linalg/__init__.py deleted file mode 100644 index e4428ca21f..0000000000 --- a/pytensor/sandbox/linalg/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from pytensor.sandbox.linalg.ops import spectral_radius_bound diff --git a/pytensor/sandbox/minimal.py b/pytensor/sandbox/minimal.py deleted file mode 100644 index c0236e6cc7..0000000000 --- a/pytensor/sandbox/minimal.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np - -from pytensor.graph.basic import Apply -from pytensor.graph.op import Op -from pytensor.tensor.type import lscalar - - -class Minimal(Op): - # TODO : need description for class - - # if the Op has any attributes, consider using them in the eq function. - # If two Apply nodes have the same inputs and the ops compare equal... - # then they will be MERGED so they had better have computed the same thing! - - __props__ = () - - def __init__(self): - # If you put things here, think about whether they change the outputs - # computed by # self.perform() - # - If they do, then you should take them into consideration in - # __eq__ and __hash__ - # - If they do not, then you should not use them in - # __eq__ and __hash__ - - super().__init__() - - def make_node(self, *args): - # HERE `args` must be PYTENSOR VARIABLES - return Apply(op=self, inputs=args, outputs=[lscalar()]) - - def perform(self, node, inputs, out_): - (output,) = out_ - # HERE `inputs` are PYTHON OBJECTS - - # do what you want here, - # but do not modify any of the arguments [inplace]. - print("perform got %i arguments" % len(inputs)) - - print("Max of input[0] is ", np.max(inputs[0])) - - # return some computed value. - # do not return something that is aliased to one of the inputs. - output[0] = np.asarray(0, dtype="int64") - - -minimal = Minimal() diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index d8836e4b7b..80946d524c 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -7,6 +7,7 @@ # Register JAX specializations import pytensor.tensor.rewriting.jax +import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.math import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.special diff --git a/pytensor/sandbox/linalg/ops.py b/pytensor/tensor/rewriting/linalg.py similarity index 100% rename from pytensor/sandbox/linalg/ops.py rename to pytensor/tensor/rewriting/linalg.py diff --git a/tests/sandbox/__init__.py b/tests/sandbox/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/sandbox/linalg/__init__.py b/tests/sandbox/linalg/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/sandbox/test_minimal.py b/tests/sandbox/test_minimal.py deleted file mode 100644 index 82e346eaf3..0000000000 --- a/tests/sandbox/test_minimal.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np -import pytest - -from pytensor import function -from pytensor.sandbox.minimal import minimal -from pytensor.tensor.type import matrix, vector -from tests import unittest_tools as utt - - -@pytest.mark.skip(reason="Unfinished test") -class TestMinimal: - """ - TODO: test dtype conversion - TODO: test that invalid types are rejected by make_node - TODO: test that each valid type for A and b works correctly - """ - - def setup_method(self): - self.rng = np.random.default_rng(utt.fetch_seed(666)) - - def test_minimal(self): - A = matrix() - b = vector() - - print("building function") - f = function([A, b], minimal(A, A, b, b, A)) - print("built") - - Aval = self.rng.standard_normal((5, 5)) - bval = np.arange(5, dtype=float) - f(Aval, bval) - print("done") diff --git a/tests/sandbox/linalg/test_linalg.py b/tests/tensor/rewriting/test_linalg.py similarity index 74% rename from tests/sandbox/linalg/test_linalg.py rename to tests/tensor/rewriting/test_linalg.py index f2cb67221c..673dd32f21 100644 --- a/tests/sandbox/linalg/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -5,10 +5,10 @@ from pytensor import function from pytensor import tensor as at from pytensor.configdefaults import config -from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import _allclose from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse +from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.slinalg import Cholesky, Solve, solve from pytensor.tensor.type import dmatrix, matrix, vector from tests import unittest_tools as utt @@ -65,53 +65,6 @@ def test_rop_lop(): assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}" -def test_spectral_radius_bound(): - tol = 10 ** (-6) - rng = np.random.default_rng(utt.fetch_seed()) - x = matrix() - radius_bound = spectral_radius_bound(x, 5) - f = pytensor.function([x], radius_bound) - - shp = (3, 4) - m = rng.random(shp) - m = np.cov(m).astype(config.floatX) - radius_bound_pytensor = f(m) - - # test the approximation - mm = m - for i in range(5): - mm = np.dot(mm, mm) - radius_bound_numpy = np.trace(mm) ** (2 ** (-5)) - assert abs(radius_bound_numpy - radius_bound_pytensor) < tol - - # test the bound - eigen_val = numpy.linalg.eig(m) - assert (eigen_val[0].max() - radius_bound_pytensor) < tol - - # test type errors - xx = vector() - ok = False - try: - spectral_radius_bound(xx, 5) - except TypeError: - ok = True - assert ok - ok = False - try: - spectral_radius_bound(x, 5.0) - except TypeError: - ok = True - assert ok - - # test value error - ok = False - try: - spectral_radius_bound(x, -5) - except ValueError: - ok = True - assert ok - - def test_transinv_to_invtrans(): X = matrix("X") Y = matrix_inverse(X) From 0bbf1a4cb5ca1224c8d3d79c67ce2fd4b595ed2c Mon Sep 17 00:00:00 2001 From: David Horsley Date: Tue, 16 May 2023 23:15:53 +1000 Subject: [PATCH 4/4] Add cholesky of L.LT rewrite --- pytensor/tensor/rewriting/linalg.py | 44 ++++++++++++++++ tests/tensor/rewriting/test_linalg.py | 75 +++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 0a53924801..8f09e52261 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node): return [x] +@register_canonicalize +@register_stabilize +@node_rewriter([Cholesky]) +def cholesky_ldotlt(fgraph, node): + """ + rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, + or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular. + + This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. + """ + if not isinstance(node.op, Cholesky): + return + + A = node.inputs[0] + if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))): + return + + l, r = A.owner.inputs + + # cholesky(dot(L,L.T)) case + if ( + getattr(l.tag, "lower_triangular", False) + and r.owner + and isinstance(r.owner.op, DimShuffle) + and r.owner.op.new_order == (1, 0) + and r.owner.inputs[0] == l + ): + if node.op.lower: + return [l] + return [r] + + # cholesky(dot(U.T,U)) case + if ( + getattr(r.tag, "upper_triangular", False) + and l.owner + and isinstance(l.owner.op, DimShuffle) + and l.owner.op.new_order == (1, 0) + and l.owner.inputs[0] == r + ): + if node.op.lower: + return [l] + return [r] + + @register_stabilize @register_specialize @node_rewriter([Det]) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 673dd32f21..9ec182cb21 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1,9 +1,12 @@ import numpy as np import numpy.linalg +import pytest +import scipy.linalg import pytensor from pytensor import function from pytensor import tensor as at +from pytensor.compile import get_default_mode from pytensor.configdefaults import config from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import _allclose @@ -105,3 +108,75 @@ def test_matrix_inverse_solve(): node = matrix_inverse(A).dot(b).owner [out] = inv_as_solve.transform(None, node) assert isinstance(out.owner.op, Solve) + + +@pytest.mark.parametrize("tag", ("lower", "upper", None)) +@pytest.mark.parametrize("cholesky_form", ("lower", "upper")) +@pytest.mark.parametrize("product", ("lower", "upper", None)) +def test_cholesky_ldotlt(tag, cholesky_form, product): + cholesky = Cholesky(lower=(cholesky_form == "lower")) + + transform_removes_chol = tag is not None and product == tag + transform_transposes = transform_removes_chol and cholesky_form != tag + + A = matrix("L") + if tag: + setattr(A.tag, tag + "_triangular", True) + + if product == "lower": + M = A.dot(A.T) + elif product == "upper": + M = A.T.dot(A) + else: + M = A + + C = cholesky(M) + f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) + + print(f.maker.fgraph.apply_nodes) + + no_cholesky_in_graph = not any( + isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes + ) + + assert no_cholesky_in_graph == transform_removes_chol + + if transform_transposes: + assert any( + isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0) + for node in f.maker.fgraph.apply_nodes + ) + + # Test some concrete value through f + # there must be lower triangular (f assumes they are) + Avs = [ + np.eye(1, dtype=pytensor.config.floatX), + np.eye(10, dtype=pytensor.config.floatX), + np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX), + ] + if not tag: + # these must be positive def + Avs.extend( + [ + np.ones((4, 4), dtype=pytensor.config.floatX) + + np.eye(4, dtype=pytensor.config.floatX), + ] + ) + + for Av in Avs: + if tag == "upper": + Av = Av.T + + if product == "lower": + Mv = Av.dot(Av.T) + elif product == "upper": + Mv = Av.T.dot(Av) + else: + Mv = Av + + assert np.all( + np.isclose( + scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")), + f(Av), + ) + )