Skip to content

Commit 09fb5c0

Browse files
committed
xe: move alg_kind_t out of kernel_ctx
Replaces the alg_kind_t defines with an interop header. This is intended to reduce the excessive noise those defines introduce when analyzing the OpenCL compilation defines.
1 parent 37b271b commit 09fb5c0

File tree

13 files changed

+246
-145
lines changed

13 files changed

+246
-145
lines changed

src/gpu/intel/binary/common.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define GPU_INTEL_BINARY_COMMON_H
1919

2020
#include "gpu/intel/include/dispatch.h"
21+
#include "gpu/intel/include/dnnl_interop.h"
2122
#include "gpu/intel/include/post_ops.h"
2223
#include "gpu/intel/include/types.h"
2324
#include "gpu/intel/include/utils.h"
@@ -371,28 +372,28 @@
371372
#define DEF_binary_op(dt, special_dt) \
372373
dt __attribute__((overloadable)) binary_op(int alg, dt src0, dt src1) { \
373374
switch (alg) { \
374-
case (BINARY_ADD): return src0 + src1; \
375-
case (BINARY_MUL): return src0 * src1; \
376-
case (BINARY_MAX): return max(src0, src1); \
377-
case (BINARY_MIN): return min(src0, src1); \
378-
case (BINARY_DIV): return src0 / src1; \
379-
case (BINARY_SUB): return src0 - src1; \
380-
case (BINARY_GE): \
375+
case binary_add: return src0 + src1; \
376+
case binary_mul: return src0 * src1; \
377+
case binary_max: return max(src0, src1); \
378+
case binary_min: return min(src0, src1); \
379+
case binary_div: return src0 / src1; \
380+
case binary_sub: return src0 - src1; \
381+
case binary_ge: \
381382
return (src0 >= src1) ? SPECIAL(special_dt, one) \
382383
: SPECIAL(special_dt, zero); \
383-
case (BINARY_GT): \
384+
case binary_gt: \
384385
return (src0 > src1) ? SPECIAL(special_dt, one) \
385386
: SPECIAL(special_dt, zero); \
386-
case (BINARY_LE): \
387+
case binary_le: \
387388
return (src0 <= src1) ? SPECIAL(special_dt, one) \
388389
: SPECIAL(special_dt, zero); \
389-
case (BINARY_LT): \
390+
case binary_lt: \
390391
return (src0 < src1) ? SPECIAL(special_dt, one) \
391392
: SPECIAL(special_dt, zero); \
392-
case (BINARY_EQ): \
393+
case binary_eq: \
393394
return (src0 == src1) ? SPECIAL(special_dt, one) \
394395
: SPECIAL(special_dt, zero); \
395-
case (BINARY_NE): \
396+
case binary_ne: \
396397
return (src0 != src1) ? SPECIAL(special_dt, one) \
397398
: SPECIAL(special_dt, zero); \
398399
} \

src/gpu/intel/binary/simple.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ status_t simple_t::pd_t::init_conf(impl::engine_t *engine) {
125125

126126
status_t simple_t::pd_t::init_kernel_ctx(
127127
compute::kernel_ctx_t &kernel_ctx) const {
128-
def_binary_alg_kinds(kernel_ctx);
129128
kernel_ctx.define_int("BINARY_ALG", conf.alg);
130129
kernel_ctx.define_int("IS_TERNARY", (conf.alg == alg_kind::binary_select));
131130

src/gpu/intel/binary/xe.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ status_t xe_t::pd_t::init_conf(impl::engine_t *engine) {
285285
}
286286

287287
status_t xe_t::pd_t::init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const {
288-
def_binary_alg_kinds(kernel_ctx);
289288
kernel_ctx.define_int("BINARY_ALG", conf.alg);
290289
kernel_ctx.define_int(
291290
"IS_TERNARY", (conf.alg == alg_kind::binary_select) ? 1 : 0);

src/gpu/intel/eltwise/ref.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
6767
const memory_desc_t *dst_md) {
6868
kernel_ctx.set_data_type(conf.data_type);
6969

70-
def_eltwise_alg_kinds(kernel_ctx);
71-
7270
kernel_ctx.define_int("ELTWISE_ALG", conf.alg);
7371
kernel_ctx.define_int("NDIMS", conf.ndims);
7472
kernel_ctx.define_int("GWS0", conf.dispatch.nd_range().global_range()[0]);

src/gpu/intel/eltwise/xe.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ compute::kernel_ctx_t xe_jit_params_t::get_kernel_ctx() const {
6767
compute::kernel_ctx_t kernel_ctx;
6868

6969
kernel_ctx.set_data_type(data_type);
70-
def_eltwise_alg_kinds(kernel_ctx);
7170

7271
kernel_ctx.define_int("ELTWISE_ALG", alg_kind);
7372

src/gpu/intel/gemm/ref.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ struct ref_jit_params_t : public trivially_serializable_t<ref_jit_params_t> {
5555

5656
kernel_ctx.define_int("WITH_POST_OP", with_post_ops);
5757
if (with_post_ops) {
58-
def_binary_alg_kinds(kernel_ctx);
59-
def_eltwise_alg_kinds(kernel_ctx);
6058
kernel_ctx.define_int("ELTWISE_ALG", eltwise_alg);
6159
}
6260
kernel_ctx.define_int("WITH_SUM", with_sum);

src/gpu/intel/include/dnnl_interop.h

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*******************************************************************************
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
#ifndef GPU_INTEL_INCLUDE_DNNL_INTEROP_H
17+
#define GPU_INTEL_INCLUDE_DNNL_INTEROP_H
18+
19+
// Intended for use in both OpenCL C and C++ code. This is largely intended to
20+
// enable validation of the portion of the oneDNN C API that is required within
21+
// OpenCL kernels.
22+
23+
#ifndef __OPENCL_VERSION__
24+
namespace {
25+
#endif
26+
27+
enum dnnl_ocl_alg_kind_t {
28+
// Eltwise algorithm kinds
29+
eltwise_relu = 0x20,
30+
eltwise_tanh = 0x21,
31+
eltwise_elu = 0x22,
32+
eltwise_square = 0x23,
33+
eltwise_abs = 0x24,
34+
eltwise_sqrt = 0x25,
35+
eltwise_linear = 0x26,
36+
eltwise_soft_relu = 0x27,
37+
eltwise_hardsigmoid = 0x28,
38+
eltwise_logistic = 0x29,
39+
eltwise_exp = 0x2a,
40+
eltwise_gelu_tanh = 0x2b,
41+
eltwise_swish = 0x2c,
42+
eltwise_log = 0x2d,
43+
eltwise_clip = 0x2e,
44+
eltwise_clip_v2 = 0x2f,
45+
eltwise_pow = 0x30,
46+
eltwise_gelu_erf = 0x31,
47+
eltwise_round = 0x32,
48+
eltwise_mish = 0x33,
49+
eltwise_hardswish = 0x34,
50+
51+
eltwise_relu_dst = 0x100,
52+
eltwise_tanh_dst = 0x101,
53+
eltwise_elu_dst = 0x102,
54+
eltwise_sqrt_dst = 0x103,
55+
eltwise_logistic_dst = 0x104,
56+
eltwise_exp_dst = 0x105,
57+
eltwise_clip_v2_dst = 0x106,
58+
59+
// Binary algorithm kinds
60+
binary_add = 0x1fff0,
61+
binary_mul = 0x1fff1,
62+
binary_max = 0x1fff2,
63+
binary_min = 0x1fff3,
64+
binary_div = 0x1fff4,
65+
binary_sub = 0x1fff5,
66+
binary_ge = 0x1fff6,
67+
binary_gt = 0x1fff7,
68+
binary_le = 0x1fff8,
69+
binary_lt = 0x1fff9,
70+
binary_eq = 0x1fffa,
71+
binary_ne = 0x1fffb,
72+
binary_select = 0x1fffc,
73+
};
74+
75+
#ifndef __OPENCL_VERSION__
76+
#include "dnnl_types.h"
77+
78+
// Eltwise algorithm kinds
79+
static_assert((int)eltwise_relu == dnnl_eltwise_relu, "dnnl API mismatch");
80+
static_assert((int)eltwise_tanh == dnnl_eltwise_tanh, "dnnl API mismatch");
81+
static_assert((int)eltwise_elu == dnnl_eltwise_elu, "dnnl API mismatch");
82+
static_assert((int)eltwise_square == dnnl_eltwise_square, "dnnl API mismatch");
83+
static_assert((int)eltwise_abs == dnnl_eltwise_abs, "dnnl API mismatch");
84+
static_assert((int)eltwise_sqrt == dnnl_eltwise_sqrt, "dnnl API mismatch");
85+
static_assert((int)eltwise_linear == dnnl_eltwise_linear, "dnnl API mismatch");
86+
static_assert(
87+
(int)eltwise_soft_relu == dnnl_eltwise_soft_relu, "dnnl API mismatch");
88+
static_assert((int)eltwise_hardsigmoid == dnnl_eltwise_hardsigmoid,
89+
"dnnl API mismatch");
90+
static_assert(
91+
(int)eltwise_logistic == dnnl_eltwise_logistic, "dnnl API mismatch");
92+
static_assert((int)eltwise_exp == dnnl_eltwise_exp, "dnnl API mismatch");
93+
static_assert(
94+
(int)eltwise_gelu_tanh == dnnl_eltwise_gelu_tanh, "dnnl API mismatch");
95+
static_assert((int)eltwise_swish == dnnl_eltwise_swish, "dnnl API mismatch");
96+
static_assert((int)eltwise_log == dnnl_eltwise_log, "dnnl API mismatch");
97+
static_assert((int)eltwise_clip == dnnl_eltwise_clip, "dnnl API mismatch");
98+
static_assert(
99+
(int)eltwise_clip_v2 == dnnl_eltwise_clip_v2, "dnnl API mismatch");
100+
static_assert((int)eltwise_pow == dnnl_eltwise_pow, "dnnl API mismatch");
101+
static_assert(
102+
(int)eltwise_gelu_erf == dnnl_eltwise_gelu_erf, "dnnl API mismatch");
103+
static_assert((int)eltwise_round == dnnl_eltwise_round, "dnnl API mismatch");
104+
static_assert((int)eltwise_mish == dnnl_eltwise_mish, "dnnl API mismatch");
105+
static_assert(
106+
(int)eltwise_hardswish == dnnl_eltwise_hardswish, "dnnl API mismatch");
107+
108+
static_assert((int)eltwise_relu_dst == dnnl_eltwise_relu_use_dst_for_bwd,
109+
"dnnl API mismatch");
110+
static_assert((int)eltwise_tanh_dst == dnnl_eltwise_tanh_use_dst_for_bwd,
111+
"dnnl API mismatch");
112+
static_assert((int)eltwise_elu_dst == dnnl_eltwise_elu_use_dst_for_bwd,
113+
"dnnl API mismatch");
114+
static_assert((int)eltwise_sqrt_dst == dnnl_eltwise_sqrt_use_dst_for_bwd,
115+
"dnnl API mismatch");
116+
static_assert(
117+
(int)eltwise_logistic_dst == dnnl_eltwise_logistic_use_dst_for_bwd,
118+
"dnnl API mismatch");
119+
static_assert((int)eltwise_exp_dst == dnnl_eltwise_exp_use_dst_for_bwd,
120+
"dnnl API mismatch");
121+
static_assert((int)eltwise_clip_v2_dst == dnnl_eltwise_clip_v2_use_dst_for_bwd,
122+
"dnnl API mismatch");
123+
124+
// Binary algorithm kinds
125+
static_assert((int)binary_add == dnnl_binary_add, "dnnl API mismatch");
126+
static_assert((int)binary_mul == dnnl_binary_mul, "dnnl API mismatch");
127+
static_assert((int)binary_max == dnnl_binary_max, "dnnl API mismatch");
128+
static_assert((int)binary_min == dnnl_binary_min, "dnnl API mismatch");
129+
static_assert((int)binary_div == dnnl_binary_div, "dnnl API mismatch");
130+
static_assert((int)binary_sub == dnnl_binary_sub, "dnnl API mismatch");
131+
static_assert((int)binary_ge == dnnl_binary_ge, "dnnl API mismatch");
132+
static_assert((int)binary_gt == dnnl_binary_gt, "dnnl API mismatch");
133+
static_assert((int)binary_le == dnnl_binary_le, "dnnl API mismatch");
134+
static_assert((int)binary_lt == dnnl_binary_lt, "dnnl API mismatch");
135+
static_assert((int)binary_eq == dnnl_binary_eq, "dnnl API mismatch");
136+
static_assert((int)binary_ne == dnnl_binary_ne, "dnnl API mismatch");
137+
static_assert((int)binary_select == dnnl_binary_select, "dnnl API mismatch");
138+
139+
#endif
140+
141+
#ifndef __OPENCL_VERSION__
142+
}
143+
#endif
144+
145+
#endif

src/gpu/intel/include/eltwise.h

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef GPU_INTEL_INCLUDE_ELTWISE_H
1818
#define GPU_INTEL_INCLUDE_ELTWISE_H
1919

20+
#include "gpu/intel/include/dnnl_interop.h"
2021
#include "gpu/intel/include/types.h"
2122

2223
#if DT_F16 == 1
@@ -268,36 +269,44 @@ POST_OP_DATA_T hardswish_bwd(
268269
float fwd_eltwise_common(int eltwise_alg, POST_OP_DATA_T x, float alpha_,
269270
float beta_, float scale_) {
270271
switch (eltwise_alg) {
271-
case RELU: return scale_ * relu_fwd(x, alpha_); break;
272-
case LINEAR: return scale_ * linear_fwd(x, alpha_, beta_); break;
273-
case SOFT_RELU: return scale_ * soft_relu_fwd(x, alpha_); break;
274-
case MISH: return scale_ * mish_fwd(x); break;
275-
case LOGISTIC: return scale_ * logistic_fwd(x); break;
276-
case TANH: return scale_ * tanh_fwd(x); break;
277-
case ELU: return scale_ * elu_fwd(x, alpha_); break;
278-
case SQUARE: return scale_ * square_fwd(x); break;
279-
case SQRT: return scale_ * sqrt_fwd(x); break;
280-
case ABS: return scale_ * abs_fwd(x); break;
281-
case EXP: return scale_ * exp_fwd(x); break;
282-
case GELU_TANH: return scale_ * gelu_tanh_fwd(x); break;
283-
case SWISH: return scale_ * swish_fwd(x, alpha_); break;
284-
case LOG: return scale_ * log_fwd(x); break;
285-
case CLIP: return scale_ * clip_fwd(x, alpha_, beta_); break;
286-
case CLIP_V2: return scale_ * clip_v2_fwd(x, alpha_, beta_); break;
287-
case POW: return scale_ * pow_fwd(x, alpha_, beta_); break;
288-
case GELU_ERF: return scale_ * gelu_erf_fwd(x); break;
289-
case ROUND: return scale_ * round_fwd(x); break;
290-
case HARDSWISH: return scale_ * hardswish_fwd(x, alpha_, beta_); break;
291-
case HARDSIGMOID:
272+
case eltwise_relu: return scale_ * relu_fwd(x, alpha_); break;
273+
case eltwise_linear:
274+
return scale_ * linear_fwd(x, alpha_, beta_);
275+
break;
276+
case eltwise_soft_relu: return scale_ * soft_relu_fwd(x, alpha_); break;
277+
case eltwise_mish: return scale_ * mish_fwd(x); break;
278+
case eltwise_logistic: return scale_ * logistic_fwd(x); break;
279+
case eltwise_tanh: return scale_ * tanh_fwd(x); break;
280+
case eltwise_elu: return scale_ * elu_fwd(x, alpha_); break;
281+
case eltwise_square: return scale_ * square_fwd(x); break;
282+
case eltwise_sqrt: return scale_ * sqrt_fwd(x); break;
283+
case eltwise_abs: return scale_ * abs_fwd(x); break;
284+
case eltwise_exp: return scale_ * exp_fwd(x); break;
285+
case eltwise_gelu_tanh: return scale_ * gelu_tanh_fwd(x); break;
286+
case eltwise_swish: return scale_ * swish_fwd(x, alpha_); break;
287+
case eltwise_log: return scale_ * log_fwd(x); break;
288+
case eltwise_clip: return scale_ * clip_fwd(x, alpha_, beta_); break;
289+
case eltwise_clip_v2:
290+
return scale_ * clip_v2_fwd(x, alpha_, beta_);
291+
break;
292+
case eltwise_pow: return scale_ * pow_fwd(x, alpha_, beta_); break;
293+
case eltwise_gelu_erf: return scale_ * gelu_erf_fwd(x); break;
294+
case eltwise_round: return scale_ * round_fwd(x); break;
295+
case eltwise_hardswish:
296+
return scale_ * hardswish_fwd(x, alpha_, beta_);
297+
break;
298+
case eltwise_hardsigmoid:
292299
return scale_ * hardsigmoid_fwd(x, alpha_, beta_);
293300
break;
294-
case RELU_DST: return scale_ * relu_fwd(x, alpha_); break;
295-
case LOGISTIC_DST: return scale_ * logistic_fwd(x); break;
296-
case TANH_DST: return scale_ * tanh_fwd(x); break;
297-
case ELU_DST: return scale_ * elu_fwd(x, alpha_); break;
298-
case SQRT_DST: return scale_ * sqrt_fwd(x); break;
299-
case EXP_DST: return scale_ * exp_fwd(x); break;
300-
case CLIP_V2_DST: return scale_ * clip_v2_fwd(x, alpha_, beta_); break;
301+
case eltwise_relu_dst: return scale_ * relu_fwd(x, alpha_); break;
302+
case eltwise_logistic_dst: return scale_ * logistic_fwd(x); break;
303+
case eltwise_tanh_dst: return scale_ * tanh_fwd(x); break;
304+
case eltwise_elu_dst: return scale_ * elu_fwd(x, alpha_); break;
305+
case eltwise_sqrt_dst: return scale_ * sqrt_fwd(x); break;
306+
case eltwise_exp_dst: return scale_ * exp_fwd(x); break;
307+
case eltwise_clip_v2_dst:
308+
return scale_ * clip_v2_fwd(x, alpha_, beta_);
309+
break;
301310
default: return x; break;
302311
}
303312
}
@@ -314,33 +323,37 @@ float bwd_eltwise(
314323
POST_OP_DATA_T x, POST_OP_DATA_T y, float alpha_, float beta_) {
315324
#ifdef ELTWISE_ALG
316325
switch (ELTWISE_ALG) {
317-
case RELU: return relu_bwd(x, y, alpha_); break;
318-
case LINEAR: return linear_bwd(x, alpha_); break;
319-
case SOFT_RELU: return soft_relu_bwd(x, y, alpha_); break;
320-
case MISH: return mish_bwd(x, y); break;
321-
case LOGISTIC: return logistic_bwd(x, y); break;
322-
case TANH: return tanh_bwd(x, y); break;
323-
case ELU: return elu_bwd(x, y, alpha_); break;
324-
case SQUARE: return square_bwd(x, y); break;
325-
case SQRT: return sqrt_bwd(x, y); break;
326-
case ABS: return abs_bwd(x, y); break;
327-
case EXP: return exp_bwd(x, y); break;
328-
case GELU_TANH: return gelu_tanh_bwd(x, y); break;
329-
case SWISH: return swish_bwd(x, y, alpha_); break;
330-
case LOG: return log_bwd(x, y); break;
331-
case CLIP: return clip_bwd(x, y, alpha_, beta_); break;
332-
case CLIP_V2: return clip_v2_bwd(x, y, alpha_, beta_); break;
333-
case POW: return pow_bwd(x, y, alpha_, beta_); break;
334-
case GELU_ERF: return gelu_erf_bwd(x, y); break;
335-
case HARDSWISH: return hardswish_bwd(x, y, alpha_, beta_); break;
336-
case HARDSIGMOID: return hardsigmoid_bwd(x, y, alpha_, beta_); break;
337-
case RELU_DST: return relu_bwd_use_dst(x, y, alpha_); break;
338-
case LOGISTIC_DST: return logistic_bwd_use_dst(x, y); break;
339-
case TANH_DST: return tanh_bwd_use_dst(x, y); break;
340-
case ELU_DST: return elu_bwd_use_dst(x, y, alpha_); break;
341-
case SQRT_DST: return sqrt_bwd_use_dst(x, y); break;
342-
case EXP_DST: return exp_bwd_use_dst(x, y); break;
343-
case CLIP_V2_DST:
326+
case eltwise_relu: return relu_bwd(x, y, alpha_); break;
327+
case eltwise_linear: return linear_bwd(x, alpha_); break;
328+
case eltwise_soft_relu: return soft_relu_bwd(x, y, alpha_); break;
329+
case eltwise_mish: return mish_bwd(x, y); break;
330+
case eltwise_logistic: return logistic_bwd(x, y); break;
331+
case eltwise_tanh: return tanh_bwd(x, y); break;
332+
case eltwise_elu: return elu_bwd(x, y, alpha_); break;
333+
case eltwise_square: return square_bwd(x, y); break;
334+
case eltwise_sqrt: return sqrt_bwd(x, y); break;
335+
case eltwise_abs: return abs_bwd(x, y); break;
336+
case eltwise_exp: return exp_bwd(x, y); break;
337+
case eltwise_gelu_tanh: return gelu_tanh_bwd(x, y); break;
338+
case eltwise_swish: return swish_bwd(x, y, alpha_); break;
339+
case eltwise_log: return log_bwd(x, y); break;
340+
case eltwise_clip: return clip_bwd(x, y, alpha_, beta_); break;
341+
case eltwise_clip_v2: return clip_v2_bwd(x, y, alpha_, beta_); break;
342+
case eltwise_pow: return pow_bwd(x, y, alpha_, beta_); break;
343+
case eltwise_gelu_erf: return gelu_erf_bwd(x, y); break;
344+
case eltwise_hardswish:
345+
return hardswish_bwd(x, y, alpha_, beta_);
346+
break;
347+
case eltwise_hardsigmoid:
348+
return hardsigmoid_bwd(x, y, alpha_, beta_);
349+
break;
350+
case eltwise_relu_dst: return relu_bwd_use_dst(x, y, alpha_); break;
351+
case eltwise_logistic_dst: return logistic_bwd_use_dst(x, y); break;
352+
case eltwise_tanh_dst: return tanh_bwd_use_dst(x, y); break;
353+
case eltwise_elu_dst: return elu_bwd_use_dst(x, y, alpha_); break;
354+
case eltwise_sqrt_dst: return sqrt_bwd_use_dst(x, y); break;
355+
case eltwise_exp_dst: return exp_bwd_use_dst(x, y); break;
356+
case eltwise_clip_v2_dst:
344357
return clip_v2_bwd_use_dst(x, y, alpha_, beta_);
345358
break;
346359

0 commit comments

Comments
 (0)