diff --git a/include/infinicore/nn/rmsnorm.hpp b/include/infinicore/nn/rmsnorm.hpp index 212b2a6e4..5891819eb 100644 --- a/include/infinicore/nn/rmsnorm.hpp +++ b/include/infinicore/nn/rmsnorm.hpp @@ -1,7 +1,7 @@ #pragma once -#include "module.hpp" #include "../ops.hpp" +#include "module.hpp" namespace infinicore::nn { @@ -57,6 +57,21 @@ class RMSNorm : public Module { */ Tensor forward(const Tensor &x) const; + /** + * @brief Forward pass: apply RMSNorm in-place with residual + * + * @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions. + * Will be modified in-place to the normalized output. + * @param residual Residual tensor to add to input before normalization. + * Will be modified in-place to the sum of input and residual. + * + * The normalization is applied over the last dimension. + * For example: + * Input: [batch, seq_len, hidden_size] -> normalize over hidden_size + * Input: [batch, hidden_size] -> normalize over hidden_size + */ + void forward_inplace(Tensor &x, Tensor &residual) const; + // Module information size_t normalized_shape() const { return normalized_shape_; } double eps() const { return eps_; } @@ -73,9 +88,9 @@ class RMSNorm : public Module { INFINICORE_NN_PARAMETER(weight); private: - size_t normalized_shape_; // Size of the feature dimension - double eps_; // Epsilon for numerical stability - DataType dtype_; // Data type for weight + size_t normalized_shape_; // Size of the feature dimension + double eps_; // Epsilon for numerical stability + DataType dtype_; // Data type for weight }; } // namespace infinicore::nn diff --git a/include/infinicore/ops/add_rms_norm.hpp b/include/infinicore/ops/add_rms_norm.hpp index e8a955a3c..50064e0a4 100644 --- a/include/infinicore/ops/add_rms_norm.hpp +++ b/include/infinicore/ops/add_rms_norm.hpp @@ -5,16 +5,14 @@ #include <utility> namespace infinicore::op { -class AddRMSNorm { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float); - static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float); // Fused Add and RMS Normalization // Returns: (normalized_result, add_result) // The add_result can be used as residual for subsequent layers -std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); -void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); +std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f); +void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f); +// Fused Add and RMS Normalization (inplace) +// normalized_result will be stored in input, add_result will be stored in residual +void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f); } // namespace infinicore::op diff --git a/include/infiniop/ops/add_rms_norm.h b/include/infiniop/ops/add_rms_norm.h index 7742c1343..52cd096a6 100644 --- a/include/infiniop/ops/add_rms_norm.h +++ b/include/infiniop/ops/add_rms_norm.h @@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( infiniopHandle_t handle, infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t 
a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc); + float epsilon); __C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); @@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de void *workspace, size_t workspace_size, void *y, + void *residual_out, const void *a, const void *b, const void *weight, - void *residual_out, void *stream); __C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index c6b01d5aa..52a269ce5 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -43,7 +43,7 @@ uint8, ) from infinicore.ops.add import add -from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ +from infinicore.ops.add_rms_norm import add_rms_norm from infinicore.ops.attention import attention from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul diff --git a/python/infinicore/ops/add_rms_norm.py b/python/infinicore/ops/add_rms_norm.py index 4ad347812..a5de7bd92 100644 --- a/python/infinicore/ops/add_rms_norm.py +++ b/python/infinicore/ops/add_rms_norm.py @@ -1,8 +1,8 @@ +import infinicore.tensor as tensor from infinicore.lib import _infinicore -from infinicore.tensor import Tensor -def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): +def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None): """ Fused Add and RMS Normalization. @@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): The add_result can be used as residual for subsequent layers. 
""" if out is None: - result = _infinicore.add_rms_norm( - a._underlying, b._underlying, weight._underlying, epsilon - ) - return (Tensor(result[0]), Tensor(result[1])) + out = tensor.empty(a.shape, dtype=a.dtype, device=a.device) + if residual is None: + residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device) - y, residual_out = out _infinicore.add_rms_norm_( - y._underlying, - residual_out._underlying, + out._underlying, + residual._underlying, a._underlying, b._underlying, weight._underlying, epsilon, ) - return (y, residual_out) - -def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5): - """In-place Fused Add and RMS Normalization.""" - _infinicore.add_rms_norm_( - y._underlying, - residual_out._underlying, - a._underlying, - b._underlying, - weight._underlying, - epsilon, - ) + return out, residual diff --git a/src/infinicore/nn/rmsnorm.cc b/src/infinicore/nn/rmsnorm.cc index a83c3a113..107dac44a 100644 --- a/src/infinicore/nn/rmsnorm.cc +++ b/src/infinicore/nn/rmsnorm.cc @@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const { return op::rms_norm(x, weight_, static_cast(eps_)); } +void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { + if (!residual) { + residual = x; + x = op::rms_norm(x, weight_, static_cast(eps_)); + } else { + if (device_.getType() == Device::Type::CPU + || device_.getType() == Device::Type::NVIDIA + || device_.getType() == Device::Type::ILUVATAR + || device_.getType() == Device::Type::METAX + || device_.getType() == Device::Type::MOORE) { + op::add_rms_norm_inplace(x, residual, weight_, static_cast(eps_)); + } else { + op::add_(residual, x, residual); + op::rms_norm_(x, residual, weight_, static_cast(eps_)); + } + } +} + std::string RMSNorm::extra_repr() const { return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc index 650ce87e6..ccba62e21 100644 --- a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc @@ -4,26 +4,30 @@ namespace infinicore::op { -common::OpDispatcher &AddRMSNorm::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm); -void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { +AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon); } -std::pair add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { +void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon); +} + +std::pair add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device()); add_rms_norm_(y, residual_out, a, b, weight, epsilon); return std::make_pair(y, 
residual_out); } -void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { - AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon); +void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { + AddRMSNorm::execute(out, residual, a, b, weight, epsilon); +} + +void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) { + add_rms_norm_(input, residual, input, residual, weight, epsilon); } } // namespace infinicore::op diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc index d6540a039..53d30a2c7 100644 --- a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc @@ -1,50 +1,53 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/add_rms_norm.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::add_rms_norm_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopAddRMSNormDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, residual, a, b, weight; + float epsilon; +}; -void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { +void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, AddRMSNorm, + seed, y->desc(), residual_out->desc(), + a->desc(), b->desc(), weight->desc(), epsilon); + + INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor); - auto desc_opt = cache.get(seed); - infiniopAddRMSNormDescriptor_t desc = nullptr; + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(residual_out), + graph::GraphTensor(a), + graph::GraphTensor(b), + graph::GraphTensor(weight), + epsilon}; - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( - context::getInfiniopHandle(device), &desc, - y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return planned; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( - desc, workspace->data(), workspace_size, - y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), + planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = 
nullptr; } -static bool registered = []() { - AddRMSNorm::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup); } // namespace infinicore::op::add_rms_norm_impl::infiniop diff --git a/src/infiniop/devices/metax/metax_kernel_common.h b/src/infiniop/devices/metax/metax_kernel_common.h index f81358d28..d850e9d04 100644 --- a/src/infiniop/devices/metax/metax_kernel_common.h +++ b/src/infiniop/devices/metax/metax_kernel_common.h @@ -8,8 +8,10 @@ // Posible maximum number of threads per block for METAX architectures // Used for picking correct kernel launch configuration -#define METAX_BLOCK_SIZE_1024 1024 #define METAX_BLOCK_SIZE_512 512 +#define METAX_BLOCK_SIZE_1024 1024 +#define METAX_BLOCK_SIZE_2048 2048 +#define METAX_BLOCK_SIZE_4096 4096 #define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess) diff --git a/src/infiniop/devices/moore/moore_kernel_common.h b/src/infiniop/devices/moore/moore_kernel_common.h index e0aea4148..d72cfb197 100644 --- a/src/infiniop/devices/moore/moore_kernel_common.h +++ b/src/infiniop/devices/moore/moore_kernel_common.h @@ -6,6 +6,7 @@ // Posible maximum number of threads per block for MUSA architectures // Used for picking correct kernel launch configuration +#define MOORE_BLOCK_SIZE_4096 4096 #define MOORE_BLOCK_SIZE_2048 2048 #define MOORE_BLOCK_SIZE_1024 1024 #define MOORE_BLOCK_SIZE_512 512 diff --git a/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh b/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh index f11643b42..02cee1ebf 100644 --- a/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh +++ b/src/infiniop/devices/nvidia/nvidia_kernel_common.cuh @@ -14,6 +14,7 @@ // Posible maximum number of threads per block for CUDA architectures // Used for picking correct kernel launch configuration #define CUDA_BLOCK_SIZE_4096 4096 +#define CUDA_BLOCK_SIZE_2048 2048 #define CUDA_BLOCK_SIZE_1024 1024 #define CUDA_BLOCK_SIZE_512 512 diff --git a/src/infiniop/ops/add_rms_norm/add_rms_norm.h b/src/infiniop/ops/add_rms_norm/add_rms_norm.h index c5d63333d..76451e982 100644 --- a/src/infiniop/ops/add_rms_norm/add_rms_norm.h +++ b/src/infiniop/ops/add_rms_norm/add_rms_norm.h @@ -33,19 +33,19 @@ infiniopHandle_t handle, \ Descriptor **desc_ptr, \ infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t residual_out_desc, \ infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t b_desc, \ infiniopTensorDescriptor_t weight_desc, \ - float epsilon, \ - infiniopTensorDescriptor_t residual_out_desc); \ + float epsilon); \ \ infiniStatus_t calculate( \ void *workspace, size_t workspace_size, \ void *y, \ + void *residual_out, \ const void *a, \ const void *b, \ const void *weight, \ - void *residual_out, \ void *stream) const; \ }; \ } diff --git a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc index 5e7954b71..a3099c5c4 100644 --- a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc +++ b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc @@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + 
float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } template -infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) { +infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) { const size_t batch_size = info->shape[0]; const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1; const size_t dim = info->dim(); @@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T } template -infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) { +infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) { static_assert(std::is_same::value || std::is_same::value, "T must be fp16_t or bf16_t"); @@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { if (_info.atype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight)); } else if (_info.wtype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight)); } else if (_info.wtype == INFINI_DTYPE_BF16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } } else if (_info.atype == INFINI_DTYPE_BF16) { if (_info.wtype == INFINI_DTYPE_BF16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight)); } else if (_info.wtype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight)); } else if (_info.wtype == INFINI_DTYPE_F16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out)); + 
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } } else if (_info.atype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out)); + CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight)); } else if (_info.atype == INFINI_DTYPE_F64) { - CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out)); + CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/add_rms_norm/info.h b/src/infiniop/ops/add_rms_norm/info.h index abe1b5059..883aed343 100644 --- a/src/infiniop/ops/add_rms_norm/info.h +++ b/src/infiniop/ops/add_rms_norm/info.h @@ -16,9 +16,9 @@ class AddRMSNormInfo { float epsilon; std::vector shape; std::vector y_strides; + std::vector residual_out_strides; std::vector a_strides; std::vector b_strides; - std::vector residual_out_strides; bool has_residual_out; size_t ndim() const { return shape.size(); } @@ -26,11 +26,11 @@ class AddRMSNormInfo { static utils::Result create( infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { + float epsilon) { auto atype = y_desc->dtype(); auto wtype = weight_desc->dtype(); diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh new file mode 100644 index 000000000..3d6b13b53 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __ADD_RMS_NORM_METAX_CUH__ +#define __ADD_RMS_NORM_METAX_CUH__ + +#include "../add_rms_norm.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca new file mode 100644 index 000000000..04355e927 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca @@ -0,0 +1,191 @@ +#include "../../../devices/metax/metax_common.h" +#include "add_rms_norm_metax.cuh" + +#include "../../../devices/metax/metax_kernel_common.h" +#include + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" + +// Kernel function template for add_rms_norm on Metax platform +template +INFINIOP_METAX_KERNEL add_rmsnormKernel( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_out_batch, + ptrdiff_t stride_residual_out_nhead, + const Tdata *__restrict__ a, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + const Tdata *__restrict__ b, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + const Tweight *__restrict__ w, + size_t nhead, + size_t dim, + float epsilon) { + add_rmsnormBlock( + y, residual_out, + stride_y_batch, stride_y_nhead, + stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + w, nhead, dim, epsilon); +} + +namespace op::add_rms_norm::metax { + +// 
Internal opaque structure for Metax device handle +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +// Destructor +Descriptor::~Descriptor() { + delete _opaque; +} + +// Create descriptor for add_rms_norm operator +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// Launch kernel with different data types +template +infiniStatus_t launchKernel( + uint32_t batch_size, size_t nhead, size_t dim, + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, + void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead, + const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead, + const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead, + const void *w, infiniDtype_t wtype, + float epsilon, + hcStream_t stream) { + +#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ + add_rmsnormKernel<<>>( \ + reinterpret_cast(y), \ + reinterpret_cast(residual_out), \ + stride_y_batch, \ + stride_y_nhead, \ + stride_residual_out_batch, \ + stride_residual_out_nhead, \ + reinterpret_cast(a), \ + stride_a_batch, \ + stride_a_nhead, \ + reinterpret_cast(b), \ + stride_b_batch, \ + stride_b_nhead, \ + reinterpret_cast(w), \ + nhead, \ + dim, \ + epsilon) + + // Handle different data type combinations following Metax pattern + if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, half, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(half, __hpcc_bfloat16, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(half, float, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(__hpcc_bfloat16, half, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(__hpcc_bfloat16, float, float); + } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float, float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +// Main calculation function +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { + + // Check workspace size + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + // Extract tensor strides and dimensions + auto stride_a_batch = _info.a_strides[0]; + auto stride_a_nhead = _info.a_strides[1]; + auto stride_b_batch = _info.b_strides[0]; + auto stride_b_nhead = _info.b_strides[1]; + auto stride_y_batch = _info.y_strides[0]; + auto stride_y_nhead = _info.y_strides[1]; + auto stride_residual_out_batch = _info.residual_out_strides[0]; + auto 
stride_residual_out_nhead = _info.residual_out_strides[1]; + auto dim = _info.dim(); + uint32_t batch_size = static_cast(_info.shape[0]); + size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; + auto stream_ = reinterpret_cast(stream); + + // Launch kernel with different block sizes + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, stream_)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add_rms_norm::metax diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h new file mode 100644 index 000000000..9d3f810f2 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h @@ -0,0 +1,8 @@ +#ifndef __ADD_RMS_NORM_MOORE_H__ +#define __ADD_RMS_NORM_MOORE_H__ + +#include "../add_rms_norm.h" + +DESCRIPTOR(moore) + +#endif diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu new file mode 100644 index 000000000..90c027ead --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu @@ -0,0 +1,183 @@ +#include "../../../devices/moore/moore_common.h" +#include "add_rms_norm_moore.h" + +#include "../../../devices/moore/moore_kernel_common.h" +#include + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" + +// Kernel function template for add_rms_norm on Moore platform +template +INFINIOP_MOORE_KERNEL add_rmsnormKernel( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_out_batch, + ptrdiff_t stride_residual_out_nhead, + const Tdata *__restrict__ a, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + const Tdata *__restrict__ b, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + const Tweight *__restrict__ w, + size_t nhead, + size_t dim, + float epsilon) { + add_rmsnormBlock( + y, residual_out, + stride_y_batch, stride_y_nhead, + stride_residual_out_batch, stride_residual_out_nhead, + 
a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + w, nhead, dim, epsilon); +} + +namespace op::add_rms_norm::moore { + +// Internal opaque structure for Moore device handle +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +// Destructor +Descriptor::~Descriptor() { + delete _opaque; +} + +// Create descriptor for add_rms_norm operator +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// Launch kernel with different data types +template +infiniStatus_t launchKernel( + uint32_t batch_size, size_t nhead, size_t dim, + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, + void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead, + const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead, + const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead, + const void *w, infiniDtype_t wtype, + float epsilon, + musaStream_t musa_stream) { + +#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ + add_rmsnormKernel<<>>( \ + reinterpret_cast(y), \ + reinterpret_cast(residual_out), \ + stride_y_batch, \ + stride_y_nhead, \ + stride_residual_out_batch, \ + stride_residual_out_nhead, \ + reinterpret_cast(a), \ + stride_a_batch, \ + stride_a_nhead, \ + reinterpret_cast(b), \ + stride_b_batch, \ + stride_b_nhead, \ + reinterpret_cast(w), \ + nhead, \ + dim, \ + epsilon) + + // Handle different data type combinations + if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, half, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(half, __mt_bfloat16, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(half, float, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(__mt_bfloat16, half, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(__mt_bfloat16, float, float); + } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float, float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +// Main calculation function +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { + + // Check workspace size + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + // Extract tensor strides and dimensions + auto stride_a_batch = _info.a_strides[0]; + auto stride_a_nhead = _info.a_strides[1]; + auto stride_b_batch = _info.b_strides[0]; + auto stride_b_nhead = _info.b_strides[1]; + auto stride_y_batch = _info.y_strides[0]; + 
auto stride_y_nhead = _info.y_strides[1]; + auto stride_residual_out_batch = _info.residual_out_strides[0]; + auto stride_residual_out_nhead = _info.residual_out_strides[1]; + auto dim = _info.dim(); + uint32_t batch_size = static_cast(_info.shape[0]); + size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; + auto musa_stream = reinterpret_cast(stream); + + // Launch kernel with appropriate block size based on device capability + if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, musa_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add_rms_norm::moore diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu index 03601205f..652f8adf3 100644 --- a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu +++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu @@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); auto info = result.take(); @@ -122,8 +122,8 @@ infiniStatus_t launchKernel( infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; @@ -143,7 +143,15 @@ infiniStatus_t Descriptor::calculate( auto cuda_stream = reinterpret_cast(stream); // launch kernel with different block sizes - if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, 
stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel( batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, @@ -151,8 +159,8 @@ infiniStatus_t Descriptor::calculate( a, stride_a_batch, stride_a_nhead, b, stride_b_batch, stride_b_nhead, weight, _info.wtype, _info.epsilon, cuda_stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { - CHECK_STATUS(launchKernel( + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, residual_out, stride_residual_out_batch, stride_residual_out_nhead, a, stride_a_batch, stride_a_nhead, b, stride_b_batch, stride_b_nhead, weight, _info.wtype, _info.epsilon, cuda_stream)); diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc index a856e5447..62187cf34 100644 --- a/src/infiniop/ops/add_rms_norm/operator.cc +++ b/src/infiniop/ops/add_rms_norm/operator.cc @@ -17,12 +17,10 @@ // #include "bang/add_rms_norm_bang.h" #endif #ifdef ENABLE_METAX_API -// TODO: Add Metax implementation -// #include "metax/add_rms_norm_metax.cuh" +#include "metax/add_rms_norm_metax.cuh" #endif #ifdef ENABLE_MOORE_API -// TODO: Add Moore implementation -// #include "moore/add_rms_norm_moore.h" +#include "moore/add_rms_norm_moore.h" #endif #ifdef ENABLE_KUNLUN_API // TODO: Add Kunlun implementation @@ -32,12 +30,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( infiniopHandle_t handle, infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { + float epsilon) { #define CREATE(CASE, NAMESPACE) \ case CASE: \ @@ -45,11 +43,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( handle, \ reinterpret_cast(desc_ptr), \ y_desc, \ + residual_out_desc, \ a_desc, \ b_desc, \ weight_desc, \ - epsilon, \ - residual_out_desc) + epsilon) switch (handle->device) { #ifdef ENABLE_CPU_API @@ -61,6 +59,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif @@ -94,6 +98,12 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif @@ -116,16 +126,16 @@ __C infiniStatus_t infiniopAddRMSNorm( void *workspace, size_t workspace_size, void *y, + void *residual_out, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) { #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ - ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream) + ->calculate(workspace, workspace_size, y, residual_out, a, b, weight, stream) switch (desc->device_type) { @@ -138,6 +148,12 @@ __C infiniStatus_t infiniopAddRMSNorm( #ifdef 
ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif @@ -173,6 +189,12 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip #ifdef ENABLE_ILUVATAR_API DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia); #endif diff --git a/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu b/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu index 6dae5af61..6e671df1b 100644 --- a/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu +++ b/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu @@ -76,7 +76,15 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, const void *x, void *stream_) const { cudaStream_t stream = (cudaStream_t)stream_; - if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, + _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, + _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel( y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); @@ -84,10 +92,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, CHECK_STATUS(launchKernel( y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { - CHECK_STATUS(launchKernel( - y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, - _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } diff --git a/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu index b083650d4..21cda3695 100644 --- a/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu +++ b/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu @@ -117,12 +117,14 @@ infiniStatus_t Descriptor::calculate( auto cuda_stream = reinterpret_cast(stream); // launch kernel with different block sizes - if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, 
stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { - CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } diff --git a/test/infinicore/ops/add_rms_norm.py b/test/infinicore/ops/add_rms_norm.py index 429d9df25..f6bf165a9 100644 --- a/test/infinicore/ops/add_rms_norm.py +++ b/test/infinicore/ops/add_rms_norm.py @@ -30,8 +30,24 @@ ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (2048, 8192, 1), + (2048, 8192, 1), + (2048, 8192, 1), + ), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + (16384, 4096, 1), + ), ] # Tolerance configuration @@ -87,12 +103,14 @@ def parse_test_cases(): y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result) - residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) + residual_out_spec = TensorSpec.from_tensor( + a_shape, a_strides, input_dtype + ) test_cases.append( TestCase( inputs=[a_spec, b_spec, w_spec], kwargs={"epsilon": _EPSILON}, - output_specs=[y_spec, residual_out_spec], # Two outputs + output_specs=None, # Two outputs comparison_target=None, tolerance=tolerance, output_count=2, # Two outputs: normalized_result and add_result @@ -101,19 +119,25 @@ def parse_test_cases(): ) # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w)) - if y_supports_inplace: - residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) - test_cases.append( - TestCase( - inputs=[a_spec, b_spec, w_spec], - kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)}, - output_specs=[y_spec, residual_out_spec], # Two outputs - comparison_target="out", - tolerance=tolerance, - output_count=2, - description=f"AddRMSNorm - INPLACE(out)", - ) - ) + # if y_supports_inplace: + # residual_out_spec = TensorSpec.from_tensor( + # a_shape, a_strides, input_dtype + # ) + # test_cases.append( + # TestCase( + # inputs=[a_spec, b_spec, w_spec], + # kwargs={ + # "epsilon": _EPSILON, + # "out": y_spec, + # "residual": residual_out_spec, + # }, + # output_specs=[y_spec, residual_out_spec], # Two outputs + # comparison_target="out", + # tolerance=tolerance, + # output_count=2, + # description=f"AddRMSNorm - 
INPLACE(out)", + # ) + # ) return test_cases @@ -127,7 +151,9 @@ def __init__(self): def get_test_cases(self): return parse_test_cases() - def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + def torch_operator( + self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs + ): """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)""" input_dtype = a.dtype @@ -144,21 +170,19 @@ def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): add_result = sum_tensor.to(input_dtype) if out is not None: - # For in-place operations, we need to handle the output tuple - if isinstance(out, (tuple, list)) and len(out) == 2: - out[0].copy_(normalized_result) - out[1].copy_(add_result) - return tuple(out) - else: - # Single output - just return normalized result for backward compatibility - out.copy_(normalized_result) - return out - + out.copy_(normalized_result) + if residual is not None: + residual.copy_(add_result) + return (normalized_result, add_result) - def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + def infinicore_operator( + self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs + ): """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)""" - return infinicore.add_rms_norm(a, b, weight, epsilon, out=out) + return infinicore.add_rms_norm( + a, b, weight, epsilon, out=out, residual=residual + ) def main(): diff --git a/test/infiniop/add_rms_norm.py b/test/infiniop/add_rms_norm.py index 930314761..e3b4f9b64 100644 --- a/test/infiniop/add_rms_norm.py +++ b/test/infiniop/add_rms_norm.py @@ -32,8 +32,24 @@ ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (2048, 8192, 1), + (2048, 8192, 1), + (2048, 8192, 1), + ), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + (16384, 4096, 1), + ), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None), ] @@ -97,7 +113,9 @@ def test( w = TestTensor(w_shape, None, w_dtype, device) eps = 1e-6 - add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps) + add_rms_norm( + y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps + ) if sync is not None: sync() @@ -109,11 +127,11 @@ def test( handle, ctypes.byref(descriptor), y.descriptor, + residual_out.descriptor, a.descriptor, b.descriptor, w.descriptor, eps, - residual_out.descriptor, ) ) @@ -136,10 +154,10 @@ def lib_add_rms_norm(): workspace.data(), workspace_size.value, y.data(), + residual_out.data(), a.data(), b.data(), w.data(), - residual_out.data(), None, ) ) @@ -147,18 +165,22 @@ def lib_add_rms_norm(): lib_add_rms_norm() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - + # Verify normalized result (y) if DEBUG: debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) - + # Verify add result (residual_out) - should be a + b - expected_residual 
= a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32) + expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to( + torch.float32 + ) expected_residual = expected_residual.to(a.torch_tensor().dtype) if DEBUG: debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) - assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) + assert torch.allclose( + residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol + ) # Profiling workflow if PROFILE: diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 618be2b05..7d6cf17e2 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -393,6 +393,7 @@ def add_rms_norm_(lib): infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, c_float, ] @@ -412,6 +413,7 @@ def add_rms_norm_(lib): c_void_p, c_void_p, c_void_p, + c_void_p, ] lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32 diff --git a/xmake.lua b/xmake.lua index d5a4ba7f7..a51435325 100644 --- a/xmake.lua +++ b/xmake.lua @@ -114,6 +114,12 @@ option("iluvatar-gpu") set_description("Whether to compile implementations for Iluvatar GPU") option_end() +option("ivcore-20") + set_default(false) + set_showmenu(true) + set_description("Use ivcore20") +option_end() + if has_config("iluvatar-gpu") then add_defines("ENABLE_ILUVATAR_API") includes("xmake/iluvatar.lua") diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 35ccf2154..57a935f4f 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -44,6 +44,9 @@ target("infiniop-iluvatar") set_warnings("all", "error") add_cuflags("-Wno-error=unused-private-field") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) + if has_config("ivcore-20") then + add_cuflags("--cuda-gpu-arch=ivcore20", {force = true}) + end add_culdflags("-fPIC") add_cxflags("-fPIC")
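For reference, a minimal PyTorch sketch (not part of the patch; the helper name and tensor names are illustrative) of the semantics the new entry points are expected to provide: add_rms_norm(a, b, weight, epsilon) returns (normalized_result, add_result), add_rms_norm_(out, residual, a, b, weight, epsilon) writes them into preallocated tensors, and add_rms_norm_inplace(input, residual, weight, epsilon) reuses input/residual as the outputs, which is what RMSNorm::forward_inplace dispatches to on CPU/NVIDIA/ILUVATAR/METAX/MOORE. The reduction runs over the last dimension; accumulating the residual add in float32 mirrors the torch_operator reference in test/infinicore/ops/add_rms_norm.py, though the exact dtype of the weight multiply there is an assumption.

import torch

def add_rms_norm_reference(a, b, weight, epsilon=1e-5):
    # Residual add accumulated in float32 for numerical stability.
    acc = a.to(torch.float32) + b.to(torch.float32)
    # RMS normalization over the last dimension.
    variance = acc.pow(2).mean(dim=-1, keepdim=True)
    normalized = acc * torch.rsqrt(variance + epsilon)
    # Apply the learned scale and cast back to the input dtype.
    normalized_result = (normalized * weight.to(torch.float32)).to(a.dtype)
    # The fused op also returns the pre-normalization sum as the new residual.
    add_result = acc.to(a.dtype)
    return normalized_result, add_result

if __name__ == "__main__":
    a = torch.randn(4, 2048, dtype=torch.float16)
    b = torch.randn_like(a)
    w = torch.ones(2048, dtype=torch.float16)
    y, resid = add_rms_norm_reference(a, b, w, epsilon=1e-6)
    print(y.shape, resid.shape)

Under this formulation, the fused backends should agree with (normalized_result, add_result) up to the tolerances configured in the tests.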