diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h
index 967faaff7f9e..22bb942cb1dd 100644
--- a/src/common/linalg_op.h
+++ b/src/common/linalg_op.h
@@ -245,11 +245,21 @@ void TransformKernel(Context const* ctx, TensorView t, Fn&& fn) {
 }
 #endif
 
+// vector-scalar multiplication
+template
+void VecScaMulFp32(Context const* ctx, linalg::VectorView x, float mul) {
+  TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
+}
+
 // vector-scalar multiplication
 template
 void VecScaMul(Context const* ctx, linalg::VectorView x, double mul) {
   CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
-  TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
+  if (ctx->DeviceFP64() != ctx->Device()) {
+    VecScaMulFp32(ctx, x, mul);
+  } else {
+    TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
+  }
 }
 
 // vector-scalar division
@@ -261,7 +271,11 @@ void VecScaDiv(Context const* ctx, linalg::VectorView x, double div) {
 template
 void LogE(Context const* ctx, linalg::VectorView x, float rt_eps = 0.0f) {
   CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
+#if defined(SYCL_LANGUAGE_VERSION)
+  TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return ::sycl::log(v + rt_eps); });
+#else
   TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); });
+#endif
 }
 
 template >* = nullptr>