Commit cf27491

Author: Dmitry Razdoburdin
Commit message: initial
1 parent 7991260 commit cf27491

File tree: 8 files changed, +81 -14 lines changed

plugin/sycl/common/optional_weight.cc

Lines changed: 16 additions & 5 deletions
@@ -8,14 +8,13 @@
 #include "../device_manager.h"

 namespace xgboost::common::sycl_impl {
-double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
-  sycl::DeviceManager device_manager;
-  auto* qu = device_manager.GetQueue(ctx->Device());

+template <typename T>
+T ElementWiseSum(::sycl::queue* qu, OptionalWeights const& weights) {
   const auto* data = weights.Data();
-  double result = 0;
+  T result = 0;
   {
-    ::sycl::buffer<double> buff(&result, 1);
+    ::sycl::buffer<T> buff(&result, 1);
     qu->submit([&](::sycl::handler& cgh) {
       auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
       cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction,
@@ -28,4 +27,16 @@ double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {

   return result;
 }
+
+double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
+  sycl::DeviceManager device_manager;
+  auto* qu = device_manager.GetQueue(ctx->Device());
+
+  bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64);
+  if (has_fp64_support) {
+    return ElementWiseSum<double>(qu, weights);
+  } else {
+    return ElementWiseSum<float>(qu, weights);
+  }
+}
 }  // namespace xgboost::common::sycl_impl
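
The pattern introduced here — probe `::sycl::aspect::fp64` once, then instantiate the same reduction kernel for `double` or `float` — can be tried standalone. A minimal sketch, assuming USM device memory and made-up names (`SumOnDevice`, the test data); an illustration of the technique, not XGBoost code:

#include <sycl/sycl.hpp>
#include <iostream>
#include <vector>

template <typename T>
T SumOnDevice(::sycl::queue& qu, const float* data, size_t n) {
  T result = 0;
  {
    // Buffer destruction at the end of this scope copies the reduced value
    // back into `result` and blocks until the kernel has finished.
    ::sycl::buffer<T> buff(&result, 1);
    qu.submit([&](::sycl::handler& cgh) {
      auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
      cgh.parallel_for(::sycl::range<1>(n), reduction,
                       [=](::sycl::id<1> pid, auto& sum) { sum += data[pid[0]]; });
    });
  }
  return result;
}

int main() {
  ::sycl::queue qu;
  std::vector<float> host(1000, 0.5f);
  float* data = ::sycl::malloc_device<float>(host.size(), qu);
  qu.memcpy(data, host.data(), host.size() * sizeof(float)).wait();

  // Same fallback as in the patch: accumulate in double only when the
  // device reports the fp64 aspect.
  double sum = qu.get_device().has(::sycl::aspect::fp64)
                   ? SumOnDevice<double>(qu, data, host.size())
                   : SumOnDevice<float>(qu, data, host.size());
  std::cout << "sum = " << sum << std::endl;
  ::sycl::free(data, qu);
  return 0;
}

Accumulating in `float` on fp64-less devices trades some precision for portability, which is the same trade the patch makes.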

plugin/sycl/common/stats.cc

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*!
+ * Copyright by Contributors 2017-2025
+ */
+#include <sycl/sycl.hpp>
+
+#include "../../../src/common/stats.h"
+
+#include "../device_manager.h"
+
+namespace xgboost::common::sycl_impl {
+void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out) {
+  sycl::DeviceManager device_manager;
+  auto* qu = device_manager.GetQueue(ctx->Device());
+
+  qu->submit([&](::sycl::handler& cgh) {
+    auto reduction =
+        ::sycl::reduction(&(out(0)), 0.0f, ::sycl::plus<float>(),
+                          ::sycl::property::reduction::initialize_to_identity());
+    cgh.parallel_for<>(::sycl::range<1>(v.Size()), reduction,
+                       [=](::sycl::id<1> pid, auto& sum) {
+                         size_t i = pid[0];
+                         sum += v(i);
+                       });
+  }).wait_and_throw();
+}
+}
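
This kernel reduces directly into `out(0)` and passes `::sycl::property::reduction::initialize_to_identity()`, so any stale value already sitting in the output slot is discarded rather than folded into the result. A self-contained sketch of that property on a USM pointer (all names and sizes here are invented for illustration):

#include <sycl/sycl.hpp>
#include <cstddef>
#include <iostream>

int main() {
  ::sycl::queue qu;
  constexpr size_t kN = 1024;
  float* v = ::sycl::malloc_shared<float>(kN, qu);
  float* out = ::sycl::malloc_shared<float>(1, qu);
  for (size_t i = 0; i < kN; ++i) v[i] = 1.0f;
  *out = 42.0f;  // stale value; initialize_to_identity makes the kernel ignore it

  qu.submit([&](::sycl::handler& cgh) {
    auto reduction =
        ::sycl::reduction(out, 0.0f, ::sycl::plus<float>(),
                          ::sycl::property::reduction::initialize_to_identity());
    cgh.parallel_for(::sycl::range<1>(kN), reduction,
                     [=](::sycl::id<1> pid, auto& sum) { sum += v[pid[0]]; });
  }).wait_and_throw();

  std::cout << *out << std::endl;  // prints 1024, not 1066
  ::sycl::free(v, qu);
  ::sycl::free(out, qu);
  return 0;
}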

plugin/sycl/context_helper.cc

Lines changed: 4 additions & 1 deletion
@@ -14,7 +14,10 @@ namespace sycl {

 DeviceOrd DeviceFP64(const DeviceOrd& device) {
   DeviceManager device_manager;
-  bool support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64);
+  bool support_fp64 = true;
+  if (device.IsSycl()) {
+    support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64);
+  }
   if (support_fp64) {
     return device;
   } else {
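
A compact restatement of the guard, as a hypothetical free function (not XGBoost's API): only a SYCL device may be asked about `::sycl::aspect::fp64`; any other target is assumed to support double.

#include <sycl/sycl.hpp>

bool SupportsFP64(bool is_sycl_device, ::sycl::queue* qu) {
  if (!is_sycl_device) {
    return true;  // non-SYCL targets: fp64 is taken for granted
  }
  return qu->get_device().has(::sycl::aspect::fp64);
}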

plugin/sycl/device_manager.cc

Lines changed: 9 additions & 2 deletions
@@ -21,18 +21,25 @@ ::sycl::queue* DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
                                  (collective::IsDistributed());
   DeviceRegister& device_register = GetDevicesRegister();
   if (not_use_default_selector) {
-    const int device_idx =
-        collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
     if (device_spec.IsSyclDefault()) {
       auto& devices = device_register.devices;
+      const int device_idx = collective::IsDistributed()
+                                 ? collective::GetRank() % devices.size()
+                                 : device_spec.ordinal;
       CHECK_LT(device_idx, devices.size());
       queue_idx = device_idx;
     } else if (device_spec.IsSyclCPU()) {
       auto& cpu_devices_idxes = device_register.cpu_devices_idxes;
+      const int device_idx = collective::IsDistributed()
+                                 ? collective::GetRank() % cpu_devices_idxes.size()
+                                 : device_spec.ordinal;
       CHECK_LT(device_idx, cpu_devices_idxes.size());
       queue_idx = cpu_devices_idxes[device_idx];
     } else if (device_spec.IsSyclGPU()) {
       auto& gpu_devices_idxes = device_register.gpu_devices_idxes;
+      const int device_idx = collective::IsDistributed()
+                                 ? collective::GetRank() % gpu_devices_idxes.size()
+                                 : device_spec.ordinal;
       CHECK_LT(device_idx, gpu_devices_idxes.size());
       queue_idx = gpu_devices_idxes[device_idx];
     } else {
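
The recurring change in all three branches is the same mapping: in distributed mode the worker's rank is wrapped with modulo, so having more ranks than local devices of the requested kind still yields a valid queue index. A sketch with illustrative names, not XGBoost's:

#include <cstddef>
#include <vector>

size_t PickQueueIdx(bool distributed, int rank, int requested_ordinal,
                    const std::vector<size_t>& device_idxs) {
  // Wrap the rank so rank >= device count maps back into range.
  const size_t device_idx = distributed
                                ? static_cast<size_t>(rank) % device_idxs.size()
                                : static_cast<size_t>(requested_ordinal);
  // The real code additionally does CHECK_LT(device_idx, device_idxs.size()).
  return device_idxs[device_idx];
}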

src/common/linalg_op.h

Lines changed: 7 additions & 3 deletions
@@ -159,9 +159,13 @@ void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
 #elif defined(SYCL_LANGUAGE_VERSION)
 template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
 void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
-  ctx->DispatchDevice([&] { cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
-                      [&] { LOG(FATAL) << "Invalid TU"; },
-                      [&] { ::xgboost::sycl::linalg::ElementWiseKernel(t, std::forward<Fn>(fn)); });
+  if (t.Device().IsCPU()) {
+    cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn));
+  } else {
+    ctx->DispatchDevice([&] { cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
+                        [&] { LOG(FATAL) << "Invalid TU"; },
+                        [&] { ::xgboost::sycl::linalg::ElementWiseKernel(t, std::forward<Fn>(fn)); });
+  }
 }
 #else
 template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
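
The new branch short-circuits the dispatcher: a tensor that already lives on the CPU runs the CPU kernel directly, so a SYCL-enabled build no longer routes CPU tensors through the SYCL path. A generic sketch of the idea (toy `Device` enum, not XGBoost's `DispatchDevice`):

enum class Device { kCPU, kSYCL };

template <typename CpuFn, typename SyclFn>
void ElementWiseDispatch(Device tensor_device, CpuFn&& cpu_kernel,
                         SyclFn&& sycl_kernel) {
  if (tensor_device == Device::kCPU) {
    cpu_kernel();  // fast path: skip the context-based dispatcher entirely
  } else {
    sycl_kernel();
  }
}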

src/common/stats.cc

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::Vector

   if (ctx->IsCUDA()) {
     cuda_impl::Mean(ctx, v, out->View(ctx->Device()));
+  } else if (ctx->IsSycl()) {
+    sycl_impl::Mean(ctx, v, out->View(ctx->Device()));
   } else {
     auto h_v = v;
     float n = v.Size();

src/common/stats.h

Lines changed: 12 additions & 1 deletion
@@ -14,7 +14,7 @@
 #include "xgboost/linalg.h"  // TensorView,VectorView
 #include "xgboost/logging.h"  // CHECK_GE

-#if !defined(XGBOOST_USE_CUDA)
+#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_SYCL)
 #include "common.h"  // AssertGPUSupport
 #endif

@@ -140,6 +140,17 @@ inline void WeightedSampleMean(Context const*, bool, linalg::MatrixView<float co
 #endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace cuda_impl

+namespace sycl_impl {
+void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);
+
+#if !defined(XGBOOST_USE_SYCL)
+inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
+  common::AssertGPUSupport();
+}
+
+#endif  // !defined(XGBOOST_USE_SYCL)
+}  // namespace sycl_impl
+
 /**
  * @brief Calculate medians for each column of the input matrix.
  */
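
The addition follows the same declare-plus-stub convention as the `cuda_impl` block above it: the header always declares `Mean`, and builds without the plugin get an inline body that fails loudly, so callers compile and link either way. A simplified sketch (the macro is real, the signature is not):

#include <cstddef>
#include <stdexcept>

namespace sycl_impl {
// Real implementation lives in the SYCL translation unit.
void Mean(const float* v, size_t n, float* out);

#if !defined(XGBOOST_USE_SYCL)
// XGBoost's stub calls common::AssertGPUSupport(); a plain build could throw.
inline void Mean(const float*, size_t, float*) {
  throw std::runtime_error("not compiled with SYCL support");
}
#endif  // !defined(XGBOOST_USE_SYCL)
}  // namespace sycl_impl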

src/objective/multiclass_obj.cu

Lines changed: 5 additions & 2 deletions
@@ -106,10 +106,13 @@ class SoftmaxMultiClassObj : public ObjFunction {
           << "Number of weights should be equal to number of data points.";
     }
     info.weights_.SetDevice(device);
-    auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_);
+    auto weights = common::MakeOptionalWeights(device, info.weights_);

     preds.SetDevice(device);
-    auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes);
+    Context cpu_context = Context();
+    auto predt = linalg::MakeTensorView(
+        device == ctx_->Device() ? this->ctx_ : &cpu_context,
+        &preds, n_samples, n_classes);
     CHECK_EQ(labels.Shape(1), 1);
     auto y1d = labels.Slice(linalg::All(), 0);
     CHECK_EQ(y1d.Shape(0), info.num_row_);
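
The second hunk handles a device mismatch: when the data has been forced onto a device other than the one the booster's context names (for example, a SYCL device without fp64 falling back to CPU), the tensor view is built against a freshly constructed CPU context. A toy sketch of the selection logic (illustrative types, not XGBoost's):

struct Ctx { int device; };

// Use the booster's context only when the data actually lives on its device;
// otherwise fall back to a default CPU context.
const Ctx* ViewContext(const Ctx* booster_ctx, int data_device, const Ctx* cpu_ctx) {
  return data_device == booster_ctx->device ? booster_ctx : cpu_ctx;
}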
