naive_matrix_multiplication.mojo

# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
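# A naive matrix multiplication example: two input matrices are initialized
# on the host, moved to an accelerator, multiplied by a custom GPU function
# in which each thread computes one output element, and copied back to the
# host for display.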

from math import ceildiv

from gpu.host import Dim
from gpu.id import block_dim, block_idx, thread_idx
from layout import Layout, LayoutTensor
from max.driver import Accelerator, Tensor, accelerator, cpu

alias float_dtype = DType.float32
alias tensor_rank = 2


fn naive_matrix_multiplication[
m_layout: Layout,
n_layout: Layout,
p_layout: Layout,
](
m: LayoutTensor[float_dtype, m_layout, MutableAnyOrigin],
n: LayoutTensor[float_dtype, n_layout, MutableAnyOrigin],
p: LayoutTensor[float_dtype, p_layout, MutableAnyOrigin],
):
"""Naive matrix multiplication of M_ij x N_jk = P_ik."""
row = block_dim.y * block_idx.y + thread_idx.y
col = block_dim.x * block_idx.x + thread_idx.x
var m_dim = p.dim(0)
var n_dim = p.dim(1)
var k_dim = m.dim(1)
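    # Guard against threads that fall outside the matrix: the grid is sized
    # in whole blocks, so it may extend past the matrix dimensions.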
if row < m_dim and col < n_dim:
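        # Accumulate the dot product of row `row` of M and column `col` of N.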
for j_index in range(k_dim):
p[row, col] = p[row, col] + m[row, j_index] * n[j_index, col]


def main():
# Attempt to connect to a compatible GPU. If one is not found, this will
# error out and exit.
gpu_device = accelerator()
    host_device = cpu()

    alias I = 5
    alias J = 4
    alias K = 6

    # Allocate the two input matrices on the host.
m_tensor = Tensor[float_dtype, tensor_rank]((I, J), host_device)
n_tensor = Tensor[float_dtype, tensor_rank]((J, K), host_device)

    # Fill them with initial values.
for m_row in range(I):
for m_col in range(J):
m_tensor[m_row, m_col] = m_row - m_col
for n_row in range(J):
for n_col in range(K):
n_tensor[n_row, n_col] = n_row + n_col
print("M matrix:", m_tensor)
print("N matrix:", n_tensor)
# Move the input matrices to the accelerator.
m_tensor = m_tensor.move_to(gpu_device)
n_tensor = n_tensor.move_to(gpu_device)

    # Allocate a tensor on the accelerator to hold the calculation results.
p_tensor = Tensor[float_dtype, tensor_rank]((I, K), gpu_device)
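
    # Create LayoutTensor views of the tensors; their layouts become
    # compile-time parameters of the GPU function below.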
m_layout_tensor = m_tensor.to_layout_tensor()
n_layout_tensor = n_tensor.to_layout_tensor()
p_layout_tensor = p_tensor.to_layout_tensor()

    # Compile the function to run across a grid on the GPU.
gpu_function = Accelerator.compile[
naive_matrix_multiplication[
m_layout_tensor.layout,
n_layout_tensor.layout,
p_layout_tensor.layout,
]
](gpu_device)

    # The grid is divided up into blocks, with an extra full block to cover
    # any remainder, so every element of the I x K output matrix is assigned
    # to a thread. This hasn't been tuned for any specific GPU.
    alias BLOCK_SIZE = 16
    num_col_blocks = ceildiv(K, BLOCK_SIZE)
    num_row_blocks = ceildiv(I, BLOCK_SIZE)

    # Launch the compiled function on the GPU. The target device is specified
# first, followed by all function arguments. The last two named parameters
# are the dimensions of the grid in blocks, and the block dimensions.
gpu_function(
gpu_device,
m_layout_tensor,
n_layout_tensor,
p_layout_tensor,
grid_dim=Dim(num_col_blocks, num_row_blocks),
block_dim=Dim(BLOCK_SIZE, BLOCK_SIZE),
)

    # Move the output tensor back onto the CPU so that we can read the results.
p_tensor = p_tensor.move_to(host_device)
print("Resulting matrix:", p_tensor)