1717#ifndef GPU_INTEL_INCLUDE_ELTWISE_H
1818#define GPU_INTEL_INCLUDE_ELTWISE_H
1919
20+ #include "gpu/intel/include/dnnl_interop.h"
2021#include "gpu/intel/include/types.h"
2122
2223#if DT_F16 == 1
@@ -268,36 +269,44 @@ POST_OP_DATA_T hardswish_bwd(
// Forward element-wise dispatch: switches on eltwise_alg, calls the matching
// *_fwd function on x, and returns that result multiplied by scale_.
// alpha_ and beta_ are passed through only to the *_fwd functions that take
// them; the remaining cases call their function with x alone.
// NOTE(review): the lines marked '-' and '+' below look like the removed and
// added sides of a diff (upper-case labels vs. eltwise_* labels) -- confirm
// which set of case labels is current before relying on the label names.
268269float fwd_eltwise_common (int eltwise_alg , POST_OP_DATA_T x , float alpha_ ,
269270 float beta_ , float scale_ ) {
270271 switch (eltwise_alg ) {
271- case RELU : return scale_ * relu_fwd (x , alpha_ ); break ;
272- case LINEAR : return scale_ * linear_fwd (x , alpha_ , beta_ ); break ;
273- case SOFT_RELU : return scale_ * soft_relu_fwd (x , alpha_ ); break ;
274- case MISH : return scale_ * mish_fwd (x ); break ;
275- case LOGISTIC : return scale_ * logistic_fwd (x ); break ;
276- case TANH : return scale_ * tanh_fwd (x ); break ;
277- case ELU : return scale_ * elu_fwd (x , alpha_ ); break ;
278- case SQUARE : return scale_ * square_fwd (x ); break ;
279- case SQRT : return scale_ * sqrt_fwd (x ); break ;
280- case ABS : return scale_ * abs_fwd (x ); break ;
281- case EXP : return scale_ * exp_fwd (x ); break ;
282- case GELU_TANH : return scale_ * gelu_tanh_fwd (x ); break ;
283- case SWISH : return scale_ * swish_fwd (x , alpha_ ); break ;
284- case LOG : return scale_ * log_fwd (x ); break ;
285- case CLIP : return scale_ * clip_fwd (x , alpha_ , beta_ ); break ;
286- case CLIP_V2 : return scale_ * clip_v2_fwd (x , alpha_ , beta_ ); break ;
287- case POW : return scale_ * pow_fwd (x , alpha_ , beta_ ); break ;
288- case GELU_ERF : return scale_ * gelu_erf_fwd (x ); break ;
289- case ROUND : return scale_ * round_fwd (x ); break ;
290- case HARDSWISH : return scale_ * hardswish_fwd (x , alpha_ , beta_ ); break ;
291- case HARDSIGMOID :
272+ case eltwise_relu : return scale_ * relu_fwd (x , alpha_ ); break ;
273+ case eltwise_linear :
274+ return scale_ * linear_fwd (x , alpha_ , beta_ );
275+ break ;
276+ case eltwise_soft_relu : return scale_ * soft_relu_fwd (x , alpha_ ); break ;
277+ case eltwise_mish : return scale_ * mish_fwd (x ); break ;
278+ case eltwise_logistic : return scale_ * logistic_fwd (x ); break ;
279+ case eltwise_tanh : return scale_ * tanh_fwd (x ); break ;
280+ case eltwise_elu : return scale_ * elu_fwd (x , alpha_ ); break ;
281+ case eltwise_square : return scale_ * square_fwd (x ); break ;
282+ case eltwise_sqrt : return scale_ * sqrt_fwd (x ); break ;
283+ case eltwise_abs : return scale_ * abs_fwd (x ); break ;
284+ case eltwise_exp : return scale_ * exp_fwd (x ); break ;
285+ case eltwise_gelu_tanh : return scale_ * gelu_tanh_fwd (x ); break ;
286+ case eltwise_swish : return scale_ * swish_fwd (x , alpha_ ); break ;
287+ case eltwise_log : return scale_ * log_fwd (x ); break ;
288+ case eltwise_clip : return scale_ * clip_fwd (x , alpha_ , beta_ ); break ;
289+ case eltwise_clip_v2 :
290+ return scale_ * clip_v2_fwd (x , alpha_ , beta_ );
291+ break ;
292+ case eltwise_pow : return scale_ * pow_fwd (x , alpha_ , beta_ ); break ;
293+ case eltwise_gelu_erf : return scale_ * gelu_erf_fwd (x ); break ;
294+ case eltwise_round : return scale_ * round_fwd (x ); break ;
295+ case eltwise_hardswish :
296+ return scale_ * hardswish_fwd (x , alpha_ , beta_ );
297+ break ;
298+ case eltwise_hardsigmoid :
292299 return scale_ * hardsigmoid_fwd (x , alpha_ , beta_ );
293300 break ;
// The *_dst labels below call the same *_fwd functions, with the same
// arguments, as their non-dst counterparts above.
294- case RELU_DST : return scale_ * relu_fwd (x , alpha_ ); break ;
295- case LOGISTIC_DST : return scale_ * logistic_fwd (x ); break ;
296- case TANH_DST : return scale_ * tanh_fwd (x ); break ;
297- case ELU_DST : return scale_ * elu_fwd (x , alpha_ ); break ;
298- case SQRT_DST : return scale_ * sqrt_fwd (x ); break ;
299- case EXP_DST : return scale_ * exp_fwd (x ); break ;
300- case CLIP_V2_DST : return scale_ * clip_v2_fwd (x , alpha_ , beta_ ); break ;
301+ case eltwise_relu_dst : return scale_ * relu_fwd (x , alpha_ ); break ;
302+ case eltwise_logistic_dst : return scale_ * logistic_fwd (x ); break ;
303+ case eltwise_tanh_dst : return scale_ * tanh_fwd (x ); break ;
304+ case eltwise_elu_dst : return scale_ * elu_fwd (x , alpha_ ); break ;
305+ case eltwise_sqrt_dst : return scale_ * sqrt_fwd (x ); break ;
306+ case eltwise_exp_dst : return scale_ * exp_fwd (x ); break ;
307+ case eltwise_clip_v2_dst :
308+ return scale_ * clip_v2_fwd (x , alpha_ , beta_ );
309+ break ;
// Any other eltwise_alg value: x is returned unchanged and scale_ is NOT
// applied on this path.
301310 default : return x ; break ;
302311 }
303312}
@@ -314,33 +323,37 @@ float bwd_eltwise(
314323 POST_OP_DATA_T x , POST_OP_DATA_T y , float alpha_ , float beta_ ) {
315324#ifdef ELTWISE_ALG
316325 switch (ELTWISE_ALG ) {
317- case RELU : return relu_bwd (x , y , alpha_ ); break ;
318- case LINEAR : return linear_bwd (x , alpha_ ); break ;
319- case SOFT_RELU : return soft_relu_bwd (x , y , alpha_ ); break ;
320- case MISH : return mish_bwd (x , y ); break ;
321- case LOGISTIC : return logistic_bwd (x , y ); break ;
322- case TANH : return tanh_bwd (x , y ); break ;
323- case ELU : return elu_bwd (x , y , alpha_ ); break ;
324- case SQUARE : return square_bwd (x , y ); break ;
325- case SQRT : return sqrt_bwd (x , y ); break ;
326- case ABS : return abs_bwd (x , y ); break ;
327- case EXP : return exp_bwd (x , y ); break ;
328- case GELU_TANH : return gelu_tanh_bwd (x , y ); break ;
329- case SWISH : return swish_bwd (x , y , alpha_ ); break ;
330- case LOG : return log_bwd (x , y ); break ;
331- case CLIP : return clip_bwd (x , y , alpha_ , beta_ ); break ;
332- case CLIP_V2 : return clip_v2_bwd (x , y , alpha_ , beta_ ); break ;
333- case POW : return pow_bwd (x , y , alpha_ , beta_ ); break ;
334- case GELU_ERF : return gelu_erf_bwd (x , y ); break ;
335- case HARDSWISH : return hardswish_bwd (x , y , alpha_ , beta_ ); break ;
336- case HARDSIGMOID : return hardsigmoid_bwd (x , y , alpha_ , beta_ ); break ;
337- case RELU_DST : return relu_bwd_use_dst (x , y , alpha_ ); break ;
338- case LOGISTIC_DST : return logistic_bwd_use_dst (x , y ); break ;
339- case TANH_DST : return tanh_bwd_use_dst (x , y ); break ;
340- case ELU_DST : return elu_bwd_use_dst (x , y , alpha_ ); break ;
341- case SQRT_DST : return sqrt_bwd_use_dst (x , y ); break ;
342- case EXP_DST : return exp_bwd_use_dst (x , y ); break ;
343- case CLIP_V2_DST :
326+ case eltwise_relu : return relu_bwd (x , y , alpha_ ); break ;
327+ case eltwise_linear : return linear_bwd (x , alpha_ ); break ;
328+ case eltwise_soft_relu : return soft_relu_bwd (x , y , alpha_ ); break ;
329+ case eltwise_mish : return mish_bwd (x , y ); break ;
330+ case eltwise_logistic : return logistic_bwd (x , y ); break ;
331+ case eltwise_tanh : return tanh_bwd (x , y ); break ;
332+ case eltwise_elu : return elu_bwd (x , y , alpha_ ); break ;
333+ case eltwise_square : return square_bwd (x , y ); break ;
334+ case eltwise_sqrt : return sqrt_bwd (x , y ); break ;
335+ case eltwise_abs : return abs_bwd (x , y ); break ;
336+ case eltwise_exp : return exp_bwd (x , y ); break ;
337+ case eltwise_gelu_tanh : return gelu_tanh_bwd (x , y ); break ;
338+ case eltwise_swish : return swish_bwd (x , y , alpha_ ); break ;
339+ case eltwise_log : return log_bwd (x , y ); break ;
340+ case eltwise_clip : return clip_bwd (x , y , alpha_ , beta_ ); break ;
341+ case eltwise_clip_v2 : return clip_v2_bwd (x , y , alpha_ , beta_ ); break ;
342+ case eltwise_pow : return pow_bwd (x , y , alpha_ , beta_ ); break ;
343+ case eltwise_gelu_erf : return gelu_erf_bwd (x , y ); break ;
344+ case eltwise_hardswish :
345+ return hardswish_bwd (x , y , alpha_ , beta_ );
346+ break ;
347+ case eltwise_hardsigmoid :
348+ return hardsigmoid_bwd (x , y , alpha_ , beta_ );
349+ break ;
350+ case eltwise_relu_dst : return relu_bwd_use_dst (x , y , alpha_ ); break ;
351+ case eltwise_logistic_dst : return logistic_bwd_use_dst (x , y ); break ;
352+ case eltwise_tanh_dst : return tanh_bwd_use_dst (x , y ); break ;
353+ case eltwise_elu_dst : return elu_bwd_use_dst (x , y , alpha_ ); break ;
354+ case eltwise_sqrt_dst : return sqrt_bwd_use_dst (x , y ); break ;
355+ case eltwise_exp_dst : return exp_bwd_use_dst (x , y ); break ;
356+ case eltwise_clip_v2_dst :
344357 return clip_v2_bwd_use_dst (x , y , alpha_ , beta_ );
345358 break ;
346359
0 commit comments