Skip to content

Commit ae3c72d

Browse files
committed
Add Accelerate framework blas__ldflags tests
1 parent 6132203 commit ae3c72d

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

pytensor/link/c/cmodule.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2873,9 +2873,21 @@ def check_libs(
28732873
)
28742874
except Exception as e:
28752875
_logger.debug(e)
2876+
try:
2877+
# 3. Mac Accelerate framework
2878+
_logger.debug("Checking Accelerate framework")
2879+
flags = ["-framework", "Accelerate"]
2880+
if rpath:
2881+
flags = [*flags, "-rpath", rpath]
2882+
validated_flags = try_blas_flag(flags)
2883+
if validated_flags == "":
2884+
raise Exception("Accelerate framework flag failed ")
2885+
return validated_flags
2886+
except Exception as e:
2887+
_logger.debug(e)
28762888
try:
28772889
_logger.debug("Checking Lapack + blas")
2878-
# 3. Try to use LAPACK + BLAS
2890+
# 4. Try to use LAPACK + BLAS
28792891
return check_libs(
28802892
all_libs,
28812893
required_libs=["lapack", "blas", "cblas", "m"],
@@ -2885,7 +2897,7 @@ def check_libs(
28852897
except Exception as e:
28862898
_logger.debug(e)
28872899
try:
2888-
# 4. Try to use BLAS alone
2900+
# 5. Try to use BLAS alone
28892901
_logger.debug("Checking blas alone")
28902902
return check_libs(
28912903
all_libs,
@@ -2896,7 +2908,7 @@ def check_libs(
28962908
except Exception as e:
28972909
_logger.debug(e)
28982910
try:
2899-
# 5. Try to use openblas
2911+
# 6. Try to use openblas
29002912
_logger.debug("Checking openblas")
29012913
return check_libs(
29022914
all_libs,

pytensor/tensor/blas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import logging
8080
import os
8181
import time
82+
from pathlib import Path
8283

8384
import numpy as np
8485

@@ -425,7 +426,7 @@ def _ldflags(
425426

426427
try:
427428
t0, t1 = t[0], t[1]
428-
assert t0 == "-"
429+
assert t0 == "-" or t == "Accelerate" or Path(t).exists()
429430
except Exception:
430431
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
431432
if libs_dir and t1 == "L":

0 commit comments

Comments
 (0)