53
53
import distutils .ccompiler
54
54
import distutils .command .clean
55
55
import os
56
+ import re
56
57
import requests
57
58
import shutil
58
59
import subprocess
@@ -226,7 +227,7 @@ class BuildBazelExtension(build_ext.build_ext):
226
227
def run (self ):
227
228
for ext in self .extensions :
228
229
self .bazel_build (ext )
229
- command .build_ext .build_ext .run (self )
230
+ command .build_ext .build_ext .run (self ) # type: ignore
230
231
231
232
def bazel_build (self , ext ):
232
233
if not os .path .exists (self .build_temp ):
@@ -260,17 +261,107 @@ def bazel_build(self, ext):
260
261
shutil .copyfile (ext_bazel_bin_path , ext_dest_path )
261
262
262
263
264
+ # Read in README.md for our long_description
265
+ cwd = os .path .dirname (os .path .abspath (__file__ ))
266
+ with open (os .path .join (cwd , "README.md" ), encoding = "utf-8" ) as f :
267
+ long_description = f .read ()
268
+
269
+ # Finds torch_xla and its subpackages
270
+ packages_to_include = find_packages (include = ['torch_xla*' ])
271
+ # Explicitly add torchax
272
+ packages_to_include .extend (find_packages (where = 'torchax' , include = ['torchax*' ]))
273
+
274
+ # Map the top-level 'torchax' package name to its source location
275
+ torchax_dir = os .path .join (cwd , 'torchax' )
276
+ package_dir_mapping = {'torch_xla' : os .path .join (cwd , 'torch_xla' )}
277
+ package_dir_mapping ['torchax' ] = os .path .join (torchax_dir , 'torchax' )
278
+
279
+
263
280
class Develop (develop .develop ):
264
281
265
282
def run (self ):
283
+ # Build the C++ extension
266
284
self .run_command ("build_ext" )
285
+
286
+ # Run the standard develop process first
287
+ # This installs dependencies, scripts, and importantly, creates an `.egg-link` file
267
288
super ().run ()
268
289
290
+ # Replace the `.egg-link` with a `.pth` file.
291
+ self .link_packages ()
292
+
293
+ def link_packages (self ):
294
+ """
295
+ There are two mechanisms to install an "editable" package in Python: `.egg-link`
296
+ and `.pth` files. setuptools uses `.egg-link` by default. However, `.egg-link`
297
+ only supports linking a single directory containg one editable package.
298
+ This function removes the `.egg-link` file and generates a `.pth` file that can
299
+ be used to link multiple packages, in particular, `torch_xla` and `torchax`.
300
+
301
+ Note that this function is only relevant in the editable package development path
302
+ (`python setup.py develop`). Nightly and release wheel builds work out of the box
303
+ without egg-link/pth.
304
+ """
305
+ # Ensure paths like self.install_dir are set
306
+ self .ensure_finalized ()
307
+
308
+ # Get the site-packages directory
309
+ target_dir = self .install_dir
310
+
311
+ # Remove the standard .egg-link file
312
+ # It's usually named based on the distribution name
313
+ dist_name = self .distribution .get_name ()
314
+ egg_link_file = os .path .join (target_dir , dist_name + '.egg-link' )
315
+ if os .path .exists (egg_link_file ):
316
+ print (f"Removing default egg-link file: { egg_link_file } " )
317
+ try :
318
+ os .remove (egg_link_file )
319
+ except OSError as e :
320
+ print (f"Warning: Could not remove { egg_link_file } : { e } " )
321
+
322
+ # Create our custom .pth file with specific paths
323
+ cwd = os .path .dirname (__file__ )
324
+ # Path containing 'torch_xla' package source: ROOT
325
+ path_for_torch_xla = os .path .abspath (cwd )
326
+ # Path containing 'torchax' package source: ROOT/torchax
327
+ path_for_torchax = os .path .abspath (os .path .join (cwd , 'torchax' ))
328
+
329
+ paths_to_add = {path_for_torch_xla , path_for_torchax }
330
+
331
+ # Construct a suitable .pth filename (PEP 660 style is good practice)
332
+ version = self .distribution .get_version ()
333
+ # Sanitize name and version for filename (replace runs of non-alphanumeric chars with '-')
334
+ sanitized_name = re .sub (r"[^a-zA-Z0-9.]+" , "_" , dist_name )
335
+ sanitized_version = re .sub (r"[^a-zA-Z0-9.]+" , "_" , version )
336
+ pth_filename = os .path .join (
337
+ target_dir , f"__editable_{ sanitized_name } _{ sanitized_version } .pth" )
338
+
339
+ # Ensure site-packages exists
340
+ os .makedirs (target_dir , exist_ok = True )
341
+
342
+ # Write the paths to the .pth file, one per line
343
+ with open (pth_filename , "w" , encoding = 'utf-8' ) as f :
344
+ for path in sorted (paths_to_add ):
345
+ f .write (path + "\n " )
346
+
347
+
348
+ def _get_jax_install_requirements ():
349
+ if not USE_NIGHTLY :
350
+ # Stable versions of JAX can be directly installed from PyPI.
351
+ return [
352
+ f'jaxlib=={ _jaxlib_version } ' ,
353
+ f'jax=={ _jax_version } ' ,
354
+ ]
355
+
356
+ # Install nightly JAX libraries from the JAX package registries.
357
+ jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{ _jax_version } -py3-none-any.whl'
358
+ jaxlib = []
359
+ for python_minor_version in [9 , 10 , 11 ]:
360
+ jaxlib .append (
361
+ f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{ _jaxlib_version } -cp3{ python_minor_version } -cp3{ python_minor_version } -manylinux2014_x86_64.whl ; python_version == "3.{ python_minor_version } "'
362
+ )
363
+ return [jax ] + jaxlib
269
364
270
- # Read in README.md for our long_description
271
- cwd = os .path .dirname (os .path .abspath (__file__ ))
272
- with open (os .path .join (cwd , "README.md" ), encoding = "utf-8" ) as f :
273
- long_description = f .read ()
274
365
275
366
setup (
276
367
name = os .environ .get ('TORCH_XLA_PACKAGE_NAME' , 'torch_xla' ),
@@ -297,7 +388,8 @@ def run(self):
297
388
"Programming Language :: Python :: 3" ,
298
389
],
299
390
python_requires = ">=3.8.0" ,
300
- packages = find_packages (include = ['torch_xla*' ]),
391
+ packages = packages_to_include ,
392
+ package_dir = package_dir_mapping ,
301
393
ext_modules = [
302
394
BazelExtension ('//:_XLAC.so' ),
303
395
BazelExtension ('//:_XLAC_cuda_functions.so' ),
@@ -310,6 +402,8 @@ def run(self):
310
402
# importlib.metadata backport required for PJRT plugin discovery prior
311
403
# to Python 3.10
312
404
'importlib_metadata>=4.6;python_version<"3.10"' ,
405
+ # Some torch operations are lowered to HLO via JAX.
406
+ * _get_jax_install_requirements (),
313
407
],
314
408
package_data = {
315
409
'torch_xla' : ['lib/*.so*' ,],
@@ -331,6 +425,8 @@ def run(self):
331
425
f'libtpu=={ _libtpu_version } ' ,
332
426
'tpu-info' ,
333
427
],
428
+ # As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla.
429
+ # However, this no-op extras_require entrypoint is left here for backwards compatibility.
334
430
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
335
431
'pallas' : [f'jaxlib=={ _jaxlib_version } ' , f'jax=={ _jax_version } ' ],
336
432
},
0 commit comments