|
16 | 16 | # under the License.
|
17 | 17 | # pylint: disable=invalid-name
|
18 | 18 | """Internal DiscoWorker for Disco ProcessSession."""
|
| 19 | + |
19 | 20 | import os
|
20 | 21 | import sys
|
21 | 22 |
|
|
31 | 32 |
|
32 | 33 | def main():
|
33 | 34 | """Main worker function"""
|
34 |
| - if len(sys.argv) != 5: |
35 |
| - print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>") |
| 35 | + |
| 36 | + if len(sys.argv) == 5 or len(sys.argv) == 6: |
| 37 | + *args, read_fd, write_fd = map(int, sys.argv[1:]) |
| 38 | + else: |
| 39 | + print( |
| 40 | + f"Expected exactly either 4 or 5 arguments, " |
| 41 | + f"but received {len(sys.argv)-1} arguments.: {sys.argv}" |
| 42 | + ) |
| 43 | + # The <num_groups> argument was added in |
| 44 | + # https://github.com/apache/tvm/pull/17180. This script |
| 45 | + # currently checks the number of arguments present, to |
| 46 | + # determine whether `num_groups` was provided. This allows |
| 47 | + # the worker.py script provided by MLC-LLM to be compatible |
| 48 | + # with either pre-17180 or post-17180 arguments. |
| 49 | + # |
| 50 | + # After the TVM version used by MLC-LLM includes #17180, the |
| 51 | + # usage can be updated to always require `len(sys.argv)==6`. |
| 52 | + print("Usage (without num groups): <worker_id> <num_workers> <read_fd> <write_fd>") |
| 53 | + print( |
| 54 | + "Usage (with num groups): <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>" |
| 55 | + ) |
36 | 56 | return
|
37 | 57 |
|
38 |
| - worker_id = int(sys.argv[1]) |
39 |
| - num_workers = int(sys.argv[2]) |
40 | 58 | if sys.platform == "win32":
|
41 | 59 | import msvcrt # pylint: disable=import-outside-toplevel,import-error
|
42 | 60 |
|
43 |
| - reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) |
44 |
| - writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) |
| 61 | + reader = msvcrt.open_osfhandle(read_fd, os.O_BINARY) |
| 62 | + writer = msvcrt.open_osfhandle(write_fd, os.O_BINARY) |
45 | 63 | else:
|
46 |
| - reader = int(sys.argv[3]) |
47 |
| - writer = int(sys.argv[4]) |
| 64 | + reader = read_fd |
| 65 | + writer = write_fd |
48 | 66 |
|
49 | 67 | worker_func = get_global_func("runtime.disco.WorkerProcess")
|
50 |
| - worker_func(worker_id, num_workers, reader, writer) |
| 68 | + worker_func(*args, reader, writer) |
51 | 69 |
|
52 | 70 |
|
53 | 71 | if __name__ == "__main__":
|
|
0 commit comments