@@ -152,8 +152,23 @@ status_t brgemm_t::get_B_pack_type(
152152 return status::success;
153153}
154154
155+ namespace {
156+ // Cache line size used for alignment
157+ constexpr size_t k_cache_line = 64 ;
158+ } // namespace
159+
155160size_t brgemm_t::get_scratchpad_size () const {
156- return brgemm_desc_.get_wsp_buffer_size ();
161+ const size_t wsp_size = brgemm_desc_.get_wsp_buffer_size ();
162+ // Align workspace end to cache line
163+ const size_t wsp_size_aligned = utils::rnd_up (wsp_size, k_cache_line);
164+
165+ const size_t batch_element_size
166+ = brgemm_desc_.brgattr .max_bs * sizeof (brgemm_batch_element_t );
167+ // Align batch element size to cache line
168+ const size_t batch_element_size_aligned
169+ = utils::rnd_up (batch_element_size, k_cache_line);
170+
171+ return wsp_size_aligned + batch_element_size_aligned;
157172}
158173
159174bool brgemm_t::is_execute_postops_valid () const {
@@ -186,8 +201,19 @@ status_t brgemm_t::generate() {
186201
187202status_t brgemm_t::execute (const void *A_ptr, const void *B_ptr,
188203 const dim_t *A_B_offsets, void *C_ptr, void *scratchpad_ptr) const {
204+
205+ if (reinterpret_cast <uintptr_t >(scratchpad_ptr) % k_cache_line != 0
206+ && get_verbose (verbose_t ::exec_profile, component_t ::ukernel))
207+ VWARN (primitive, ukernel, " Scratchpad is not cache-line aligned" );
208+
189209 const auto batch_size = brgemm_desc_.brgattr .max_bs ;
190- std::vector<brgemm_batch_element_t > v_batch_element (batch_size);
210+
211+ // Batch elements at the end of aligned workspace
212+ const size_t wsp_size = brgemm_desc_.get_wsp_buffer_size ();
213+ const size_t wsp_size_aligned = utils::rnd_up (wsp_size, k_cache_line);
214+ auto *v_batch_element = reinterpret_cast <brgemm_batch_element_t *>(
215+ reinterpret_cast <char *>(scratchpad_ptr) + wsp_size_aligned);
216+
191217 for (int i = 0 ; i < batch_size; i++) {
192218 v_batch_element[i].offset .A = A_B_offsets[2 * i];
193219 v_batch_element[i].offset .B = A_B_offsets[2 * i + 1 ];
@@ -196,7 +222,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
196222 if (get_verbose (verbose_t ::exec_profile, component_t ::ukernel)) {
197223 double start_ms = get_msec ();
198224 brgemm_kernel_execute (brgemm_kernel_, batch_size, A_ptr, B_ptr,
199- v_batch_element. data () , C_ptr, scratchpad_ptr,
225+ v_batch_element, C_ptr, scratchpad_ptr,
200226 /* dynamic_values = */ nullptr );
201227 double duration_ms = get_msec () - start_ms;
202228
@@ -206,7 +232,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
206232 duration_ms);
207233 } else {
208234 brgemm_kernel_execute (brgemm_kernel_, batch_size, A_ptr, B_ptr,
209- v_batch_element. data () , C_ptr, scratchpad_ptr,
235+ v_batch_element, C_ptr, scratchpad_ptr,
210236 /* dynamic_values = */ nullptr );
211237 }
212238 return status::success;
@@ -228,8 +254,18 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
228254 }
229255 }
230256
257+ if (reinterpret_cast <uintptr_t >(scratchpad_ptr) % k_cache_line != 0
258+ && get_verbose (verbose_t ::exec_profile, component_t ::ukernel))
259+ VWARN (primitive, ukernel, " Scratchpad is not cache-line aligned" );
260+
231261 const auto batch_size = brgemm_desc_.brgattr .max_bs ;
232- std::vector<brgemm_batch_element_t > v_batch_element (batch_size);
262+
263+ // Batch elements at the end of aligned workspace
264+ const size_t wsp_size = brgemm_desc_.get_wsp_buffer_size ();
265+ const size_t wsp_size_aligned = utils::rnd_up (wsp_size, k_cache_line);
266+ auto *v_batch_element = reinterpret_cast <brgemm_batch_element_t *>(
267+ reinterpret_cast <char *>(scratchpad_ptr) + wsp_size_aligned);
268+
233269 for (int i = 0 ; i < batch_size; i++) {
234270 v_batch_element[i].offset .A = A_B_offsets[2 * i];
235271 v_batch_element[i].offset .B = A_B_offsets[2 * i + 1 ];
@@ -265,7 +301,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
265301 if (get_verbose (verbose_t ::exec_profile, component_t ::ukernel)) {
266302 double start_ms = get_msec ();
267303 brgemm_kernel_execute_postops (brgemm_kernel_, batch_size, A_ptr, B_ptr,
268- v_batch_element. data () , const_cast <void *>(C_ptr), D_ptr,
304+ v_batch_element, const_cast <void *>(C_ptr), D_ptr,
269305 post_ops_data, scratchpad_ptr,
270306 /* dynamic_values = */ nullptr );
271307 double duration_ms = get_msec () - start_ms;
@@ -276,7 +312,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr,
276312 duration_ms);
277313 } else {
278314 brgemm_kernel_execute_postops (brgemm_kernel_, batch_size, A_ptr, B_ptr,
279- v_batch_element. data () , const_cast <void *>(C_ptr), D_ptr,
315+ v_batch_element, const_cast <void *>(C_ptr), D_ptr,
280316 post_ops_data, scratchpad_ptr,
281317 /* dynamic_values = */ nullptr );
282318 }
0 commit comments