Skip to content

Commit c933cf1

Browse files
committed
Merge branch 'main' into moe_config_diff_resolve
2 parents a8820c1 + 39456f3 commit c933cf1

File tree

14 files changed

+575
-67
lines changed

14 files changed

+575
-67
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
3434
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
3535

3636
# Supported AMD GPU architectures.
37-
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
37+
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
3838

3939
#
4040
# Supported/expected torch versions for CUDA/ROCm.

csrc/quantization/fp8/amd/hip_float8_impl.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#if defined(__HIPCC__) && \
4-
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
3+
#if defined(__HIPCC__) && defined(__gfx942__)
54
#define __HIP__MI300__
65
#endif
76

csrc/rocm/attention.cu

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,7 @@
2424
#include "../attention/dtype_fp8.cuh"
2525
#include "../quantization/fp8/amd/quant_utils.cuh"
2626

27-
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
28-
defined(__gfx941__) || defined(__gfx942__))
27+
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
2928
#define __HIP__MI300_MI250__
3029
#endif
3130

csrc/rocm/custom.cu

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
4848
at::cuda::getCurrentCUDAStream(), CuCount);
4949
}
5050

51+
void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
52+
void* scale_b, const int M, const int K, const int Kp,
53+
const int N, const int Otp_in, cudaStream_t stream,
54+
const int CuCount);
55+
56+
void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
57+
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in,
58+
const int64_t Otp_in, const int64_t CuCount) {
59+
auto M = in_a.size(0);
60+
auto K = in_a.size(1);
61+
auto Kp = in_a.stride(0);
62+
int N = N_in;
63+
int Otp = Otp_in;
64+
wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(),
65+
scale_a.data_ptr(), scale_b.data_ptr(), M, K, Kp, N, Otp,
66+
at::cuda::getCurrentCUDAStream(), CuCount);
67+
}
68+
5169
void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
5270
cudaStream_t stream, const int solidx);
5371

0 commit comments

Comments
 (0)