Skip to content

Commit 8ad7493

Browse files
committed
[no-relnote] Refactor driver library discovery
This change aligns the driver file discovery with device discovery and allows other sources such as nvsandboxutils to be added. Signed-off-by: Evan Lezar <[email protected]>
1 parent 108b8c5 commit 8ad7493

File tree

7 files changed

+182
-49
lines changed

7 files changed

+182
-49
lines changed

internal/nvsandboxutils/gen/nvsandboxutils/nvsandboxutils.yml

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ TRANSLATOR:
4949
const:
5050
- {action: accept, from: "^NVSANDBOXUTILS_"}
5151
- {action: accept, from: "^nvSandboxUtils"}
52+
- {action: replace, from: "^NVSANDBOXUTILS_255_MASK_", to: "MASK255_" }
5253
- {action: replace, from: "^NVSANDBOXUTILS_"}
5354
- {action: replace, from: "^nvSandboxUtils"}
5455
- {action: accept, from: "^NV"}

pkg/nvcdi/driver-nvml.go renamed to internal/platform-support/dgpu/driver-nvml.go

+30-46
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
**/
1616

17-
package nvcdi
17+
package dgpu
1818

1919
import (
2020
"fmt"
@@ -31,66 +31,50 @@ import (
3131
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
3232
)
3333

34-
// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
35-
// The supplied NVML Library is used to query the expected driver version.
36-
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string, ldconfigPath string, version string) (discover.Discover, error) {
37-
return newDriverVersionDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
38-
}
39-
40-
func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
41-
libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
34+
// newNvmlDriverDiscoverer constructs a discoverer for the libraries, GSP firmware, and binaries of the driver installation.
35+
func (o *options) newNvmlDriverDiscoverer() (discover.Discover, error) {
36+
libraries, err := o.newNvmlDriverLibraryDiscoverer()
4237
if err != nil {
4338
return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err)
4439
}
4540

46-
ipcs, err := discover.NewIPCDiscoverer(logger, driver.Root)
47-
if err != nil {
48-
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
49-
}
50-
51-
firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version)
41+
firmwares, err := o.newNvmlDriverFirmwareDiscoverer()
5242
if err != nil {
5343
return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err)
5444
}
5545

56-
binaries := NewDriverBinariesDiscoverer(logger, driver.Root)
46+
binaries := o.newNvmlDriverBinariesDiscoverer()
5747

5848
d := discover.Merge(
5949
libraries,
60-
ipcs,
6150
firmwares,
6251
binaries,
6352
)
6453

6554
return d, nil
6655
}
6756

68-
// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
69-
func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
70-
libraryPaths, err := getVersionLibs(logger, driver, version)
57+
// newNvmlDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
58+
func (o *options) newNvmlDriverLibraryDiscoverer() (discover.Discover, error) {
59+
libraryPaths, err := getVersionLibs(o.logger, o.driver, o.version)
7160
if err != nil {
7261
return nil, fmt.Errorf("failed to get libraries for driver version: %v", err)
7362
}
7463

7564
libraries := discover.NewMounts(
76-
logger,
65+
o.logger,
7766
lookup.NewFileLocator(
78-
lookup.WithLogger(logger),
79-
lookup.WithRoot(driver.Root),
67+
lookup.WithLogger(o.logger),
68+
lookup.WithRoot(o.driver.Root),
8069
),
81-
driver.Root,
70+
o.driver.Root,
8271
libraryPaths,
8372
)
8473

85-
updateLDCache, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath)
86-
87-
d := discover.Merge(
88-
discover.WithDriverDotSoSymlinks(
89-
libraries,
90-
version,
91-
nvidiaCDIHookPath,
92-
),
93-
updateLDCache,
74+
d := discover.WithDriverDotSoSymlinks(
75+
libraries,
76+
o.version,
77+
o.nvidiaCDIHookPath,
9478
)
9579

9680
return d, nil
@@ -138,31 +122,31 @@ func getCustomFirmwareClassPath(logger logger.Interface) string {
138122
return strings.TrimSpace(string(customFirmwareClassPath))
139123
}
140124

141-
// NewDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version.
142-
func NewDriverFirmwareDiscoverer(logger logger.Interface, driverRoot string, version string) (discover.Discover, error) {
143-
gspFirmwareSearchPaths, err := getFirmwareSearchPaths(logger)
125+
// newNvmlDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version.
126+
func (o *options) newNvmlDriverFirmwareDiscoverer() (discover.Discover, error) {
127+
gspFirmwareSearchPaths, err := getFirmwareSearchPaths(o.logger)
144128
if err != nil {
145129
return nil, fmt.Errorf("failed to get firmware search paths: %v", err)
146130
}
147-
gspFirmwarePaths := filepath.Join("nvidia", version, "gsp*.bin")
131+
gspFirmwarePaths := filepath.Join("nvidia", o.version, "gsp*.bin")
148132
return discover.NewMounts(
149-
logger,
133+
o.logger,
150134
lookup.NewFileLocator(
151-
lookup.WithLogger(logger),
152-
lookup.WithRoot(driverRoot),
135+
lookup.WithLogger(o.logger),
136+
lookup.WithRoot(o.driver.Root),
153137
lookup.WithSearchPaths(gspFirmwareSearchPaths...),
154138
),
155-
driverRoot,
139+
o.driver.Root,
156140
[]string{gspFirmwarePaths},
157141
), nil
158142
}
159143

160-
// NewDriverBinariesDiscoverer creates a discoverer for GSP firmware associated with the GPU driver.
161-
func NewDriverBinariesDiscoverer(logger logger.Interface, driverRoot string) discover.Discover {
144+
// newNvmlDriverBinariesDiscoverer creates a discoverer for the binaries (such as nvidia-smi) located in the driver root.
145+
func (o *options) newNvmlDriverBinariesDiscoverer() discover.Discover {
162146
return discover.NewMounts(
163-
logger,
164-
lookup.NewExecutableLocator(logger, driverRoot),
165-
driverRoot,
147+
o.logger,
148+
lookup.NewExecutableLocator(o.logger, o.driver.Root),
149+
o.driver.Root,
166150
[]string{
167151
"nvidia-smi", /* System management interface */
168152
"nvidia-debugdump", /* GPU coredump utility */
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package dgpu
18+
19+
import (
20+
"fmt"
21+
22+
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
23+
)
24+
25+
// newNvsandboxutilsDriverDiscoverer constructs a discoverer from the specified nvsandboxutils library.
26+
func (o *options) newNvsandboxutilsDriverDiscoverer() (discover.Discover, error) {
27+
if o.nvsandboxutilslib == nil {
28+
return nil, nil
29+
}
30+
return nil, fmt.Errorf("nvsandboxutils driver discovery is not implemented")
31+
}
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/**
2+
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package dgpu
18+
19+
import (
20+
"errors"
21+
"fmt"
22+
23+
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
24+
)
25+
26+
// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
27+
func NewDriverDiscoverer(opts ...Option) (discover.Discover, error) {
28+
o := new(opts...)
29+
30+
if o.version == "" {
31+
return nil, fmt.Errorf("a version must be specified")
32+
}
33+
34+
var discoverers []discover.Discover
35+
var errs error
36+
37+
nvsandboxutilsDiscoverer, err := o.newNvsandboxutilsDriverDiscoverer()
38+
if err != nil {
39+
// TODO: Log a warning
40+
errs = errors.Join(errs, err)
41+
} else if nvsandboxutilsDiscoverer != nil {
42+
discoverers = append(discoverers, nvsandboxutilsDiscoverer)
43+
}
44+
45+
nvmlDiscoverer, err := o.newNvmlDriverDiscoverer()
46+
if err != nil {
47+
// TODO: Log a warning
48+
errs = errors.Join(errs, err)
49+
} else if nvmlDiscoverer != nil {
50+
discoverers = append(discoverers, nvmlDiscoverer)
51+
}
52+
53+
if len(discoverers) == 0 {
54+
return nil, errs
55+
}
56+
57+
cached := discover.WithCache(
58+
discover.FirstValid(
59+
discoverers...,
60+
),
61+
)
62+
updateLDCache, _ := discover.NewLDCacheUpdateHook(o.logger, cached, o.nvidiaCDIHookPath, o.ldconfigPath)
63+
64+
ipcs, err := discover.NewIPCDiscoverer(o.logger, o.driver.Root)
65+
if err != nil {
66+
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
67+
}
68+
69+
return discover.Merge(
70+
cached,
71+
updateLDCache,
72+
ipcs,
73+
), nil
74+
}

internal/platform-support/dgpu/options.go

+25
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@ package dgpu
1818

1919
import (
2020
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
21+
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
2122
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
2223
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
2324
)
2425

2526
type options struct {
2627
logger logger.Interface
28+
driver *root.Driver
2729
devRoot string
30+
ldconfigPath string
2831
nvidiaCDIHookPath string
2932

3033
isMigDevice bool
@@ -33,6 +36,9 @@ type options struct {
3336
migCaps nvcaps.MigCaps
3437
migCapsError error
3538

39+
// version stores the driver version.
40+
version string
41+
3642
nvsandboxutilslib nvsandboxutils.Interface
3743
}
3844

@@ -45,6 +51,19 @@ func WithDevRoot(root string) Option {
4551
}
4652
}
4753

54+
// WithDriver sets the driver whose Root is used when locating driver files.
func WithDriver(driver *root.Driver) Option {
55+
return func(l *options) {
56+
l.driver = driver
57+
}
58+
}
59+
60+
// WithLdconfigPath sets the path to the ldconfig program that is passed to the
// ldcache update hook.
61+
func WithLdconfigPath(path string) Option {
62+
return func(l *options) {
63+
l.ldconfigPath = path
64+
}
65+
}
66+
4867
// WithLogger sets the logger for the library
4968
func WithLogger(logger logger.Interface) Option {
5069
return func(l *options) {
@@ -72,3 +91,9 @@ func WithNvsandboxuitilsLib(nvsandboxutilslib nvsandboxutils.Interface) Option {
7291
l.nvsandboxutilslib = nvsandboxutilslib
7392
}
7493
}
94+
95+
// WithVersion sets the driver version. NewDriverDiscoverer returns an error if
// no version has been set.
func WithVersion(version string) Option {
96+
return func(l *options) {
97+
l.version = version
98+
}
99+
}

pkg/nvcdi/common-nvml.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
23+
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
2324
)
2425

2526
// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
@@ -41,7 +42,15 @@ func (l *nvmllib) newCommonNVMLDiscoverer(version string) (discover.Discover, er
4142
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
4243
}
4344

44-
driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath, l.ldconfigPath, version)
45+
driverFiles, err := dgpu.NewDriverDiscoverer(
46+
dgpu.WithDevRoot(l.devRoot),
47+
dgpu.WithDriver(l.driver),
48+
dgpu.WithLdconfigPath(l.ldconfigPath),
49+
dgpu.WithLogger(l.logger),
50+
dgpu.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath),
51+
dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib),
52+
dgpu.WithVersion(version),
53+
)
4554
if err != nil {
4655
return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err)
4756
}

pkg/nvcdi/management.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2929
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
3030
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
3132
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3233
)
3334

@@ -76,10 +77,18 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
7677

7778
version, err := (*nvcdilib)(m).getDriverVersion()
7879
if err != nil {
79-
return nil, fmt.Errorf("failed to get CUDA version: %v", err)
80+
return nil, fmt.Errorf("failed to get driver version: %v", err)
8081
}
8182

82-
driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCDIHookPath, m.ldconfigPath, version)
83+
driver, err := dgpu.NewDriverDiscoverer(
84+
dgpu.WithDevRoot(m.devRoot),
85+
dgpu.WithDriver(m.driver),
86+
dgpu.WithLdconfigPath(m.ldconfigPath),
87+
dgpu.WithLogger(m.logger),
88+
dgpu.WithNVIDIACDIHookPath(m.nvidiaCDIHookPath),
89+
dgpu.WithNvsandboxuitilsLib(m.nvsandboxutilslib),
90+
dgpu.WithVersion(version),
91+
)
8392
if err != nil {
8493
return nil, fmt.Errorf("failed to create driver library discoverer: %v", err)
8594
}

0 commit comments

Comments
 (0)