Skip to content

Commit dd15ce4

Browse files
committed
Add CUDA impl
1 parent 4bb6330 commit dd15ce4

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

source/adapters/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ add_ur_adapter(${TARGET_NAME}
3838
${CMAKE_CURRENT_SOURCE_DIR}/queue.cpp
3939
${CMAKE_CURRENT_SOURCE_DIR}/sampler.hpp
4040
${CMAKE_CURRENT_SOURCE_DIR}/sampler.cpp
41+
${CMAKE_CURRENT_SOURCE_DIR}/tensor_map.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/tracing.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/usm.cpp
4344
${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp

source/adapters/cuda/tensor_map.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
//===--------- tensor_map.cpp - CUDA Adapter ------------------------------===//
2+
//
3+
// Copyright (C) 2024 Intel Corporation
4+
//
5+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See LICENSE.TXT
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include <cuda.h>
12+
#include <ur_api.h>
13+
14+
#include "context.hpp"
15+
16+
struct ur_exp_tensor_map_handle_t_ {
17+
CUtensorMap Map;
18+
};
19+
20+
#define CONVERT(URTYPE, CUTYPE) \
21+
if (URTYPE & UrType) \
22+
return CUTYPE;
23+
24+
inline CUtensorMapDataType
25+
convertUrToCuDataType(ur_exp_tensor_map_data_type_flags_t UrType) {
26+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT8,
27+
CU_TENSOR_MAP_DATA_TYPE_UINT8);
28+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT16,
29+
CU_TENSOR_MAP_DATA_TYPE_UINT16);
30+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT32,
31+
CU_TENSOR_MAP_DATA_TYPE_UINT32);
32+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_INT32,
33+
CU_TENSOR_MAP_DATA_TYPE_INT32);
34+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_UINT64,
35+
CU_TENSOR_MAP_DATA_TYPE_UINT64);
36+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_INT64,
37+
CU_TENSOR_MAP_DATA_TYPE_INT64);
38+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT16,
39+
CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
40+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT32,
41+
CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
42+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT64,
43+
CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
44+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_BFLOAT16,
45+
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
46+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_FLOAT32_FTZ,
47+
CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ);
48+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_TFLOAT32,
49+
CU_TENSOR_MAP_DATA_TYPE_TFLOAT32);
50+
CONVERT(UR_EXP_TENSOR_MAP_DATA_TYPE_FLAG_TFLOAT32_FTZ,
51+
CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ);
52+
throw "convertUrToCuDataType failed!";
53+
}
54+
55+
CUtensorMapInterleave
56+
convertUrToCuInterleave(ur_exp_tensor_map_interleave_flags_t UrType) {
57+
CONVERT(UR_EXP_TENSOR_MAP_INTERLEAVE_FLAG_NONE,
58+
CU_TENSOR_MAP_INTERLEAVE_NONE);
59+
CONVERT(UR_EXP_TENSOR_MAP_INTERLEAVE_FLAG_16B, CU_TENSOR_MAP_INTERLEAVE_16B);
60+
CONVERT(UR_EXP_TENSOR_MAP_INTERLEAVE_FLAG_32B, CU_TENSOR_MAP_INTERLEAVE_32B);
61+
throw "convertUrToCuInterleave failed!";
62+
}
63+
64+
CUtensorMapSwizzle
65+
convertUrToCuSwizzle(ur_exp_tensor_map_swizzle_flags_t UrType) {
66+
CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_NONE, CU_TENSOR_MAP_SWIZZLE_NONE);
67+
CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_32B, CU_TENSOR_MAP_SWIZZLE_32B);
68+
CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_64B, CU_TENSOR_MAP_SWIZZLE_64B);
69+
CONVERT(UR_EXP_TENSOR_MAP_SWIZZLE_FLAG_128B, CU_TENSOR_MAP_SWIZZLE_128B);
70+
throw "convertUrToCuSwizzle failed!";
71+
}
72+
73+
CUtensorMapL2promotion
74+
convertUrToL2promotion(ur_exp_tensor_map_l2_promotion_flags_t UrType) {
75+
CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_NONE,
76+
CU_TENSOR_MAP_L2_PROMOTION_NONE);
77+
CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_64B,
78+
CU_TENSOR_MAP_L2_PROMOTION_L2_64B);
79+
CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_128B,
80+
CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
81+
CONVERT(UR_EXP_TENSOR_MAP_L2_PROMOTION_FLAG_256B,
82+
CU_TENSOR_MAP_L2_PROMOTION_L2_256B);
83+
throw "convertUrToCul2promotion failed!";
84+
}
85+
86+
CUtensorMapFloatOOBfill
87+
convertUrToCuOOBfill(ur_exp_tensor_map_oob_fill_flags_t UrType) {
88+
CONVERT(UR_EXP_TENSOR_MAP_OOB_FILL_FLAG_NONE,
89+
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
90+
CONVERT(UR_EXP_TENSOR_MAP_OOB_FILL_FLAG_REQUEST_ZERO_FMA,
91+
CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA);
92+
throw "convertUrToCuDataOOBfill failed!";
93+
}
94+
95+
UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeIm2ColExp(
96+
ur_device_handle_t hDevice,
97+
ur_exp_tensor_map_data_type_flags_t TensorMapType, uint32_t TensorRank,
98+
void *GlobalAddress, const uint64_t *GlobalDim,
99+
const uint64_t *GlobalStrides, const int *PixelBoxLowerCorner,
100+
const int *PixelBoxUpperCorner, uint32_t ChannelsPerPixel,
101+
uint32_t PixelsPerColumn, const uint32_t *ElementStrides,
102+
ur_exp_tensor_map_interleave_flags_t Interleave,
103+
ur_exp_tensor_map_swizzle_flags_t Swizzle,
104+
ur_exp_tensor_map_l2_promotion_flags_t L2Promotion,
105+
ur_exp_tensor_map_oob_fill_flags_t OobFill,
106+
ur_exp_tensor_map_handle_t *hTensorMap) {
107+
ScopedContext Active(hDevice);
108+
try {
109+
UR_CHECK_ERROR(cuTensorMapEncodeIm2col(
110+
&(*hTensorMap)->Map, convertUrToCuDataType(TensorMapType), TensorRank,
111+
GlobalAddress, GlobalDim, GlobalStrides, PixelBoxLowerCorner,
112+
PixelBoxUpperCorner, ChannelsPerPixel, PixelsPerColumn, ElementStrides,
113+
convertUrToCuInterleave(Interleave), convertUrToCuSwizzle(Swizzle),
114+
convertUrToL2promotion(L2Promotion), convertUrToCuOOBfill(OobFill)));
115+
} catch (ur_result_t Err) {
116+
return Err;
117+
}
118+
return UR_RESULT_SUCCESS;
119+
}
120+
UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeTiledExp(
121+
ur_device_handle_t hDevice,
122+
ur_exp_tensor_map_data_type_flags_t TensorMapType, uint32_t TensorRank,
123+
void *GlobalAddress, const uint64_t *GlobalDim,
124+
const uint64_t *GlobalStrides, const uint32_t *BoxDim,
125+
const uint32_t *ElementStrides,
126+
ur_exp_tensor_map_interleave_flags_t Interleave,
127+
ur_exp_tensor_map_swizzle_flags_t Swizzle,
128+
ur_exp_tensor_map_l2_promotion_flags_t L2Promotion,
129+
ur_exp_tensor_map_oob_fill_flags_t OobFill,
130+
ur_exp_tensor_map_handle_t *hTensorMap) {
131+
ScopedContext Active(hDevice);
132+
try {
133+
UR_CHECK_ERROR(cuTensorMapEncodeTiled(
134+
&(*hTensorMap)->Map, convertUrToCuDataType(TensorMapType), TensorRank,
135+
GlobalAddress, GlobalDim, GlobalStrides, BoxDim, ElementStrides,
136+
convertUrToCuInterleave(Interleave), convertUrToCuSwizzle(Swizzle),
137+
convertUrToL2promotion(L2Promotion), convertUrToCuOOBfill(OobFill)));
138+
} catch (ur_result_t Err) {
139+
return Err;
140+
}
141+
return UR_RESULT_SUCCESS;
142+
}

0 commit comments

Comments
 (0)