Skip to content

Commit 6aeda51

Browse files
authored
Update worker.py for compatibility with upstream TVM (#275)
This commit updates `mlc_llm.cli.worker` to be compatible with upstream TVM apache/tvm#17180, which adds a `num_groups` argument to the disco worker function. To de-couple this compatibility from a general TVM version bump, this commit has a check on the number of `worker.py` arguments provided, to determine whether the `num_groups` argument is present. After the TVM version used by MLC-LLM is updated to include the upstream changes, this check can be removed.
1 parent fb6ec41 commit 6aeda51

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

python/mlc_llm/cli/worker.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""Internal DiscoWorker for Disco ProcessSession."""
19+
1920
import os
2021
import sys
2122

@@ -31,23 +32,40 @@
3132

3233
def main():
3334
"""Main worker function"""
34-
if len(sys.argv) != 5:
35-
print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
35+
36+
if len(sys.argv) == 5 or len(sys.argv) == 6:
37+
*args, read_fd, write_fd = map(int, sys.argv[1:])
38+
else:
39+
print(
40+
f"Expected exactly either 4 or 5 arguments, "
41+
f"but received {len(sys.argv)-1} arguments.: {sys.argv}"
42+
)
43+
# The <num_groups> argument was added in
44+
# https://github.com/apache/tvm/pull/17180. This script
45+
# currently checks the number of arguments present, to
46+
# determine whether `num_groups` was provided. This allows
47+
# the worker.py script provided by MLC-LLM to be compatible
48+
# with either pre-17180 or post-17180 arguments.
49+
#
50+
# After the TVM version used by MLC-LLM includes #17180, the
51+
# usage can be updated to always require `len(sys.argv)==6`.
52+
print("Usage (without num groups): <worker_id> <num_workers> <read_fd> <write_fd>")
53+
print(
54+
"Usage (with num groups): <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>"
55+
)
3656
return
3757

38-
worker_id = int(sys.argv[1])
39-
num_workers = int(sys.argv[2])
4058
if sys.platform == "win32":
4159
import msvcrt # pylint: disable=import-outside-toplevel,import-error
4260

43-
reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
44-
writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
61+
reader = msvcrt.open_osfhandle(read_fd, os.O_BINARY)
62+
writer = msvcrt.open_osfhandle(write_fd, os.O_BINARY)
4563
else:
46-
reader = int(sys.argv[3])
47-
writer = int(sys.argv[4])
64+
reader = read_fd
65+
writer = write_fd
4866

4967
worker_func = get_global_func("runtime.disco.WorkerProcess")
50-
worker_func(worker_id, num_workers, reader, writer)
68+
worker_func(*args, reader, writer)
5169

5270

5371
if __name__ == "__main__":

0 commit comments

Comments
 (0)