|
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of functions.
// This is especially useful for CUDA kernels, since specializing them to particular
// input sizes can often allow the compiler to unroll loops and place arrays into
// registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
//   static void run(T y) {
//     T val = x * x + y;
//     std::cout << val << std::endl;
//   }
// };
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   if (x == 0) {
//     SquareOffset<T, 0>::run(y);
//   } else if (x == 1) {
//     SquareOffset<T, 1>::run(y);
//   } else if (x == 2) {
//     SquareOffset<T, 2>::run(y);
//   }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   constexpr int64_t xmin = 0;
//   constexpr int64_t xmax = 2;
//   DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
//   static void run(T z, T w) {
//     T val = x + y + z + w;
//     std::cout << val << std::endl;
//   }
// };
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, T z, T w) {
//   constexpr int64_t xmin = 1;
//   constexpr int64_t xmax = 3;
//   constexpr int64_t ymin = 2;
//   constexpr int64_t ymax = 5;
//   DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.

// The helper structs below live in an anonymous namespace: they are private
// machinery, and user code should go through DispatchKernel1D/DispatchKernel2D.
namespace {

// 1D dispatch: general case.
// Kernel is the template we want to dispatch to; it takes a typename and an
// int64_t as template arguments and exposes a static run() method that
// accepts any number of arguments of any type.
// An extra template cursor posN is advanced by template recursion until it
// equals the run-time argument N, at which point Kernel<T, posN> is invoked.
template<
    template<typename, int64_t> class Kernel,
    typename T,
    int64_t lowN,
    int64_t highN,
    int64_t posN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (posN == N) {
      // The compile-time cursor matches the run-time request: dispatch.
      Kernel<T, posN>::run(args...);
      return;
    }
    if (posN < N) {
      // Advance the cursor via template recursion.
      DispatchKernelHelper1D<Kernel, T, lowN, highN, posN + 1, Args...>::run(
          N, args...);
    }
    // Falling through means N < posN, which the public entry point's range
    // check rules out; nothing happens in that case (same as before).
  }
};


// 1D dispatch: specialization for posN == highN.
// This base case terminates the template recursion.
template<
    template<typename, int64_t> class Kernel,
    typename T,
    int64_t lowN,
    int64_t highN,
    typename... Args>
struct DispatchKernelHelper1D<Kernel, T, lowN, highN, highN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (N == highN) {
      Kernel<T, highN>::run(args...);
    }
    // Otherwise N was out of range; the public entry point prevents this.
  }
};


// 2D dispatch: general case.
// Same idea as the 1D case, but with two cursors (posN, posM) that are
// advanced by template recursion until they equal the run-time pair (N, M),
// at which point Kernel<T, posN, posM> is invoked.
template<
    template<typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t lowN, int64_t highN, int64_t posN,
    int64_t lowM, int64_t highM, int64_t posM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (posN == N && posM == M) {
      Kernel<T, posN, posM>::run(args...);
      return;
    }
    if (posN < N && posM < M) {
      // Both cursors are short of their targets, so advance them together.
      // Advancing one at a time would also work, but this shortens the
      // recursion chain.
      DispatchKernelHelper2D<
          Kernel, T, lowN, highN, posN + 1, lowM, highM, posM + 1,
          Args...>::run(N, M, args...);
      return;
    }
    if (posN < N) {
      // Only the N cursor still needs to advance.
      DispatchKernelHelper2D<
          Kernel, T, lowN, highN, posN + 1, lowM, highM, posM,
          Args...>::run(N, M, args...);
      return;
    }
    if (posM < M) {
      // Only the M cursor still needs to advance.
      DispatchKernelHelper2D<
          Kernel, T, lowN, highN, posN, lowM, highM, posM + 1,
          Args...>::run(N, M, args...);
    }
  }
};


// 2D dispatch: partial specialization for posN == highN.
// Stops recursion in the N dimension; only the M cursor may still advance.
template<
    template<typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t lowN, int64_t highN,
    int64_t lowM, int64_t highM, int64_t posM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel, T, lowN, highN, highN, lowM, highM, posM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (highN == N && posM == M) {
      Kernel<T, highN, posM>::run(args...);
      return;
    }
    if (posM < highM) {
      DispatchKernelHelper2D<
          Kernel, T, lowN, highN, highN, lowM, highM, posM + 1,
          Args...>::run(N, M, args...);
    }
    // Out-of-range (N, M) is prevented by the public entry point.
  }
};


// 2D dispatch: partial specialization for posM == highM.
// Stops recursion in the M dimension; only the N cursor may still advance.
template<
    template<typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t lowN, int64_t highN, int64_t posN,
    int64_t lowM, int64_t highM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel, T, lowN, highN, posN, lowM, highM, highM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (posN == N && highM == M) {
      Kernel<T, posN, highM>::run(args...);
      return;
    }
    if (posN < highN) {
      DispatchKernelHelper2D<
          Kernel, T, lowN, highN, posN + 1, lowM, highM, highM,
          Args...>::run(N, M, args...);
    }
    // Out-of-range (N, M) is prevented by the public entry point.
  }
};


// 2D dispatch: partial specialization for posN == highN and posM == highM.
// Final base case: both cursors have reached their upper bounds.
template<
    template<typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t lowN, int64_t highN,
    int64_t lowM, int64_t highM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel, T, lowN, highN, highN, lowM, highM, highM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (highN == N && highM == M) {
      Kernel<T, highN, highM>::run(args...);
    }
    // Out-of-range (N, M) is prevented by the public entry point.
  }
};

} // namespace
| 227 | + |
| 228 | + |
| 229 | +// This is the function we expect users to call to dispatch to 1D functions |
| 230 | +template< |
| 231 | + template<typename, int64_t> class Kernel, |
| 232 | + typename T, |
| 233 | + int64_t minN, |
| 234 | + int64_t maxN, |
| 235 | + typename... Args |
| 236 | +> |
| 237 | +void DispatchKernel1D(const int64_t N, Args... args) { |
| 238 | + if (minN <= N && N <= maxN) { |
| 239 | + // Kick off the template recursion by calling the Helper with curN = minN |
| 240 | + DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...); |
| 241 | + } |
| 242 | + // Maybe throw an error if we tried to dispatch outside the allowed range? |
| 243 | +} |
| 244 | + |
| 245 | + |
| 246 | +// This is the function we expect users to call to dispatch to 2D functions |
| 247 | +template< |
| 248 | + template<typename, int64_t, int64_t> class Kernel, |
| 249 | + typename T, |
| 250 | + int64_t minN, int64_t maxN, |
| 251 | + int64_t minM, int64_t maxM, |
| 252 | + typename... Args |
| 253 | +> |
| 254 | +void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { |
| 255 | + if (minN <= N && N <= maxN && minM <= M && M <= maxM) { |
| 256 | + // Kick off the template recursion by calling the Helper with curN = minN |
| 257 | + // and curM = minM |
| 258 | + DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...); |
| 259 | + } |
| 260 | + // Maybe throw an error if we tried to dispatch outside the specified range? |
| 261 | +} |