From df801f118c7f00c729999e3373ea8d33ed9449a0 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Fri, 20 Feb 2026 02:40:12 -0800 Subject: [PATCH 1/2] fix for fp32 --- src/common/linalg_op.h | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index 967faaff7f9e..daaf3ffabff2 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -245,11 +245,21 @@ void TransformKernel(Context const* ctx, TensorView t, Fn&& fn) { } #endif +// vector-scalar multiplication +template +void VecScaMulFp32(Context const* ctx, linalg::VectorView x, float mul) { + TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; }); +} + // vector-scalar multiplication template void VecScaMul(Context const* ctx, linalg::VectorView x, double mul) { CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal); - TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; }); + if (ctx->DeviceFP64() != ctx->Device()) { + VecScaMulFp32(ctx, x, mul); + } else { + TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; }); + } } // vector-scalar division @@ -261,7 +271,11 @@ void VecScaDiv(Context const* ctx, linalg::VectorView x, double div) { template void LogE(Context const* ctx, linalg::VectorView x, float rt_eps = 0.0f) { CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal); - TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); }); + #if defined(SYCL_LANGUAGE_VERSION) + TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return ::sycl::log(v + rt_eps); }); + #else + TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); }); + #endif } template >* = nullptr> From 5a034e83ead37121a6c076d6c2029e5973c6e271 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Fri, 20 Feb 2026 02:44:24 -0800 Subject: [PATCH 2/2] precommit --- src/common/linalg_op.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index daaf3ffabff2..22bb942cb1dd 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -271,11 +271,11 @@ void VecScaDiv(Context const* ctx, linalg::VectorView x, double div) { template void LogE(Context const* ctx, linalg::VectorView x, float rt_eps = 0.0f) { CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal); - #if defined(SYCL_LANGUAGE_VERSION) - TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return ::sycl::log(v + rt_eps); }); - #else - TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); }); - #endif +#if defined(SYCL_LANGUAGE_VERSION) + TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return ::sycl::log(v + rt_eps); }); +#else + TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); }); +#endif } template >* = nullptr>