Skip to content

Commit 37b271b

Browse files
akalininankalinin
authored and committed
x64: ukernel: use part of scratchpad as offsets buffer.
1 parent a0b185f commit 37b271b

File tree

1 file changed

+43
-7
lines changed

1 file changed

+43
-7
lines changed

src/cpu/x64/ukernel/brgemm.cpp

Lines changed: 43 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -152,8 +152,23 @@ status_t brgemm_t::get_B_pack_type(
152152
return status::success;
153153
}
154154

155+
namespace {
156+
// Cache line size used for alignment
157+
constexpr size_t k_cache_line = 64;
158+
} // namespace
159+
155160
size_t brgemm_t::get_scratchpad_size() const {
156-
return brgemm_desc_.get_wsp_buffer_size();
161+
const size_t wsp_size = brgemm_desc_.get_wsp_buffer_size();
162+
// Align workspace end to cache line
163+
const size_t wsp_size_aligned = utils::rnd_up(wsp_size, k_cache_line);
164+
165+
const size_t batch_element_size
166+
= brgemm_desc_.brgattr.max_bs * sizeof(brgemm_batch_element_t);
167+
// Align batch element size to cache line
168+
const size_t batch_element_size_aligned
169+
= utils::rnd_up(batch_element_size, k_cache_line);
170+
171+
return wsp_size_aligned + batch_element_size_aligned;
157172
}
158173

159174
bool brgemm_t::is_execute_postops_valid() const {
@@ -186,8 +201,19 @@ status_t brgemm_t::generate() {
186201

187202
status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
188203
const dim_t *A_B_offsets, void *C_ptr, void *scratchpad_ptr) const {
204+
205+
if (reinterpret_cast<uintptr_t>(scratchpad_ptr) % k_cache_line != 0
206+
&& get_verbose(verbose_t::exec_profile, component_t::ukernel))
207+
VWARN(primitive, ukernel, "Scratchpad is not cache-line aligned");
208+
189209
const auto batch_size = brgemm_desc_.brgattr.max_bs;
190-
std::vector<brgemm_batch_element_t> v_batch_element(batch_size);
210+
211+
// Batch elements at the end of aligned workspace
212+
const size_t wsp_size = brgemm_desc_.get_wsp_buffer_size();
213+
const size_t wsp_size_aligned = utils::rnd_up(wsp_size, k_cache_line);
214+
auto *v_batch_element = reinterpret_cast<brgemm_batch_element_t *>(
215+
reinterpret_cast<char *>(scratchpad_ptr) + wsp_size_aligned);
216+
191217
for (int i = 0; i < batch_size; i++) {
192218
v_batch_element[i].offset.A = A_B_offsets[2 * i];
193219
v_batch_element[i].offset.B = A_B_offsets[2 * i + 1];
@@ -196,7 +222,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
196222
if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) {
197223
double start_ms = get_msec();
198224
brgemm_kernel_execute(brgemm_kernel_, batch_size, A_ptr, B_ptr,
199-
v_batch_element.data(), C_ptr, scratchpad_ptr,
225+
v_batch_element, C_ptr, scratchpad_ptr,
200226
/* dynamic_values = */ nullptr);
201227
double duration_ms = get_msec() - start_ms;
202228

@@ -206,7 +232,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
206232
duration_ms);
207233
} else {
208234
brgemm_kernel_execute(brgemm_kernel_, batch_size, A_ptr, B_ptr,
209-
v_batch_element.data(), C_ptr, scratchpad_ptr,
235+
v_batch_element, C_ptr, scratchpad_ptr,
210236
/* dynamic_values = */ nullptr);
211237
}
212238
return status::success;
@@ -228,8 +254,18 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
228254
}
229255
}
230256

257+
if (reinterpret_cast<uintptr_t>(scratchpad_ptr) % k_cache_line != 0
258+
&& get_verbose(verbose_t::exec_profile, component_t::ukernel))
259+
VWARN(primitive, ukernel, "Scratchpad is not cache-line aligned");
260+
231261
const auto batch_size = brgemm_desc_.brgattr.max_bs;
232-
std::vector<brgemm_batch_element_t> v_batch_element(batch_size);
262+
263+
// Batch elements at the end of aligned workspace
264+
const size_t wsp_size = brgemm_desc_.get_wsp_buffer_size();
265+
const size_t wsp_size_aligned = utils::rnd_up(wsp_size, k_cache_line);
266+
auto *v_batch_element = reinterpret_cast<brgemm_batch_element_t *>(
267+
reinterpret_cast<char *>(scratchpad_ptr) + wsp_size_aligned);
268+
233269
for (int i = 0; i < batch_size; i++) {
234270
v_batch_element[i].offset.A = A_B_offsets[2 * i];
235271
v_batch_element[i].offset.B = A_B_offsets[2 * i + 1];
@@ -265,7 +301,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
265301
if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) {
266302
double start_ms = get_msec();
267303
brgemm_kernel_execute_postops(brgemm_kernel_, batch_size, A_ptr, B_ptr,
268-
v_batch_element.data(), const_cast<void *>(C_ptr), D_ptr,
304+
v_batch_element, const_cast<void *>(C_ptr), D_ptr,
269305
post_ops_data, scratchpad_ptr,
270306
/* dynamic_values = */ nullptr);
271307
double duration_ms = get_msec() - start_ms;
@@ -276,7 +312,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
276312
duration_ms);
277313
} else {
278314
brgemm_kernel_execute_postops(brgemm_kernel_, batch_size, A_ptr, B_ptr,
279-
v_batch_element.data(), const_cast<void *>(C_ptr), D_ptr,
315+
v_batch_element, const_cast<void *>(C_ptr), D_ptr,
280316
post_ops_data, scratchpad_ptr,
281317
/* dynamic_values = */ nullptr);
282318
}

0 commit comments

Comments
 (0)