Skip to content

[UR] Replace loader handles with field at start of handle data #17118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions unified-runtime/scripts/generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,25 +223,6 @@ def _mako_loader_cpp(path, namespace, tags, version, specs, meta):
"make_loader_cpp path %s namespace %s version %s\n" % (path, namespace, version)
)
loc = 0
template = "ldrddi.hpp.mako"
fin = os.path.join(templates_dir, template)

name = "%s_ldrddi" % (namespace)
filename = "%s.hpp" % (name)
fout = os.path.join(path, filename)

print("Generating %s..." % fout)
loc += util.makoWrite(
fin,
fout,
name=name,
ver=version,
namespace=namespace,
tags=tags,
specs=specs,
meta=meta,
)

template = "ldrddi.cpp.mako"
fin = os.path.join(templates_dir, template)

Expand Down
220 changes: 31 additions & 189 deletions unified-runtime/scripts/templates/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def get_adapter_manifests(specs):
objs.append(obj)
return objs


"""
Public:
returns a list of all loader API functions' names
Expand Down Expand Up @@ -1510,39 +1511,6 @@ def get_initial_null_set(obj):
return ""


"""
Public:
returns true if the function always wraps output pointers in loader handles
"""


def always_wrap_outputs(obj):
cname = obj_traits.class_name(obj)
return (cname, obj["name"]) in [
("$xProgram", "Link"),
("$xProgram", "LinkExp"),
]


"""
Private:
returns the list of parameters, filtering based on desc tags
"""


def _filter_param_list(params, filters1=["[in]", "[in,out]", "[out]"], filters2=[""]):
lst = []
for p in params:
for f1 in filters1:
if f1 in p["desc"]:
for f2 in filters2:
if f2 in p["desc"]:
lst.append(p)
break
break
return lst


"""
Public:
returns a list of dict of each pfntables needed
Expand All @@ -1560,131 +1528,6 @@ def get_pfncbtables(specs, meta, namespace, tags):
return tables


"""
Public:
returns a list of dict for converting loader input parameters
"""


def get_loader_prologue(namespace, tags, obj, meta):
prologue = []

params = _filter_param_list(obj["params"], ["[in]"])
for item in params:
if param_traits.is_mbz(item):
continue
if type_traits.is_class_handle(item["type"], meta):
name = subt(namespace, tags, item["name"])
tname = _remove_const_ptr(subt(namespace, tags, item["type"]))

# e.g., "xe_device_handle_t" -> "xe_device_object_t"
obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname)
fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname)

if type_traits.is_pointer(item["type"]):
range_start = param_traits.range_start(item)
range_end = param_traits.range_end(item)
prologue.append(
{
"name": name,
"obj": obj_name,
"range": (range_start, range_end),
"type": tname,
"factory": fty_name,
"pointer": "*",
}
)
else:
prologue.append(
{
"name": name,
"obj": obj_name,
"optional": param_traits.is_optional(item),
"pointer": "",
}
)

return prologue


"""
Private:
Takes a list of struct members and recursively searches for class handles.
Returns a list of class handles with access chains to reach them (e.g.
"struct_a->struct_b.handle"). Also handles ranges of class handles and
ranges of structs with class handle members, although the latter only works
to one level of recursion i.e. a range of structs with a range of structs
with a handle member will not work.
"""


def get_struct_handle_members(
namespace, tags, meta, members, parent="", is_struct_range=False
):
handle_members = []
for m in members:
if type_traits.is_class_handle(m["type"], meta):
m_tname = _remove_const_ptr(subt(namespace, tags, m["type"]))
m_objname = re.sub(r"(\w+)_handle_t", r"\1_object_t", m_tname)
# We can deal with a range of handles, but not if it's in a range of structs
if param_traits.is_range(m) and not is_struct_range:
handle_members.append(
{
"parent": parent,
"name": m["name"],
"obj_name": m_objname,
"type": m_tname,
"range_start": param_traits.range_start(m),
"range_end": param_traits.range_end(m),
}
)
else:
handle_members.append(
{
"parent": parent,
"name": m["name"],
"obj_name": m_objname,
"optional": param_traits.is_optional(m),
}
)
elif type_traits.is_struct(m["type"], meta):
member_struct_members = type_traits.get_struct_members(m["type"], meta)
if param_traits.is_range(m):
# If we've hit a range of structs we need to start a new recursion looking
# for handle members. We do not support range within range, so skip that
if is_struct_range:
continue
range_handle_members = get_struct_handle_members(
namespace, tags, meta, member_struct_members, "", True
)
if range_handle_members:
handle_members.append(
{
"parent": parent,
"name": m["name"],
"type": subt(namespace, tags, _remove_const_ptr(m["type"])),
"range_start": param_traits.range_start(m),
"range_end": param_traits.range_end(m),
"handle_members": range_handle_members,
}
)
else:
# If it's just a struct we can keep recursing in search of handles
m_is_pointer = type_traits.is_pointer(m["type"])
new_parent_deref = "->" if m_is_pointer else "."
new_parent = m["name"] + new_parent_deref
handle_members += get_struct_handle_members(
namespace,
tags,
meta,
member_struct_members,
new_parent,
is_struct_range,
)

return handle_members


"""
Public:
Strips a string of all dereferences.
Expand All @@ -1702,37 +1545,6 @@ def strip_deref(string_to_strip):
return string_to_strip.replace("->", "")


"""
Public:
Takes a function object and recurses through its struct parameters to return
a list of structs that have handle object members the loader will need to
convert.
"""


def get_object_handle_structs_to_convert(namespace, tags, obj, meta):
structs = []
params = _filter_param_list(obj["params"], ["[in]"])

for item in params:
if type_traits.is_struct(item["type"], meta):
members = type_traits.get_struct_members(item["type"], meta)
handle_members = get_struct_handle_members(namespace, tags, meta, members)
if handle_members:
name = subt(namespace, tags, item["name"])
tname = _remove_const_ptr(subt(namespace, tags, item["type"]))
struct = {
"name": name,
"type": tname,
"optional": param_traits.is_optional(item),
"members": handle_members,
}

structs.append(struct)

return structs


"""
Public:
returns an enum object with the given name
Expand Down Expand Up @@ -2039,3 +1851,33 @@ def get_etors(obj):
if etor_traits.is_deprecated_etor(item):
continue
yield item


"""
Public:
Returns the first non-optional non-native handle for the given function.

If it is a range, `name[0]` will be returned instead of `name`.
"""


def get_dditable_field(obj):
for p in obj["params"]:
if param_traits.is_optional(p):
continue
if "native_handle_t" in p["type"]:
continue

if param_traits.is_range(p):
if not p["type"].endswith("_handle_t*"):
continue
return p["name"] + "[0]"
else:
if not p["type"].endswith("_handle_t"):
continue
return p["name"]
obj_class = obj["class"]
name = obj["name"]
raise RuntimeError(
f"Function {obj_class}::{name} does not have a non-optional handle argument"
)
Loading
Loading