Skip to content

Define CpuDeviceInterface #636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/torchcodec/_core/CpuDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ extern "C" {
namespace facebook::torchcodec {
namespace {

bool g_cpu = registerDeviceInterface(
static bool g_cpu = registerDeviceInterface(
torch::kCPU,
[](const torch::Device& device) { return new CpuDeviceInterface(device); });

Expand All @@ -36,6 +36,7 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(

CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
: DeviceInterface(device) {
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
if (device_.type() != torch::kCPU) {
throw std::runtime_error("Unsupported device: " + device_.str());
}
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ extern "C" {
namespace facebook::torchcodec {
namespace {

bool g_cuda =
static bool g_cuda =
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
return new CudaDeviceInterface(device);
});
Expand Down Expand Up @@ -165,6 +165,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) {

CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
: DeviceInterface(device) {
TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
if (device_.type() != torch::kCUDA) {
throw std::runtime_error("Unsupported device: " + device_.str());
}
Expand Down
39 changes: 20 additions & 19 deletions src/torchcodec/_core/DeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ namespace facebook::torchcodec {

namespace {
using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
std::mutex g_interface_mutex;
std::unique_ptr<DeviceInterfaceMap> g_interface_map;
static std::mutex g_interface_mutex;

DeviceInterfaceMap& getDeviceMap() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Thank you for the fix, @scotts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curiously, this isn't the whole problem, but I suspect it's part of the problem. I also think we should make sure that we have a use of the boolean used to capture the result of registering the device. But, again, that hasn't solved all of the problems on my end, but I'm confident it's better than what we had before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@scotts, just to conclude the discussion. I see you've merged. Were the other changes needed on internal code base or the adjustments you've done actually fix the issues?

static DeviceInterfaceMap deviceMap;
return deviceMap;
}

std::string getDeviceType(const std::string& device) {
size_t pos = device.find(':');
Expand All @@ -29,35 +33,31 @@ bool registerDeviceInterface(
torch::DeviceType deviceType,
CreateDeviceInterfaceFn createInterface) {
std::scoped_lock lock(g_interface_mutex);
if (!g_interface_map) {
// We delay this initialization until runtime to avoid the Static
// Initialization Order Fiasco:
//
// https://en.cppreference.com/w/cpp/language/siof
g_interface_map = std::make_unique<DeviceInterfaceMap>();
}
DeviceInterfaceMap& deviceMap = getDeviceMap();

TORCH_CHECK(
g_interface_map->find(deviceType) == g_interface_map->end(),
deviceMap.find(deviceType) == deviceMap.end(),
"Device interface already registered for ",
deviceType);
g_interface_map->insert({deviceType, createInterface});
deviceMap.insert({deviceType, createInterface});

return true;
}

torch::Device createTorchDevice(const std::string device) {
std::scoped_lock lock(g_interface_mutex);
std::string deviceType = getDeviceType(device);
DeviceInterfaceMap& deviceMap = getDeviceMap();

auto deviceInterface = std::find_if(
g_interface_map->begin(),
g_interface_map->end(),
deviceMap.begin(),
deviceMap.end(),
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
return device.rfind(
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
});
TORCH_CHECK(
deviceInterface != g_interface_map->end(),
"Unsupported device: ",
device);
deviceInterface != deviceMap.end(), "Unsupported device: ", device);

return torch::Device(device);
}
Expand All @@ -66,13 +66,14 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
const torch::Device& device) {
auto deviceType = device.type();
std::scoped_lock lock(g_interface_mutex);
DeviceInterfaceMap& deviceMap = getDeviceMap();

TORCH_CHECK(
g_interface_map->find(deviceType) != g_interface_map->end(),
deviceMap.find(deviceType) != deviceMap.end(),
"Unsupported device: ",
device);

return std::unique_ptr<DeviceInterface>(
(*g_interface_map)[deviceType](device));
return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
}

} // namespace facebook::torchcodec
Loading