| #pragma once |
| #include "vectorization.cuh" |
|
|
| namespace vllm { |
|
|
// Adapter that turns a per-element functor into a per-vector functor.
//
// Template parameters:
//   VEC_SIZE  number of elements held by one vector
//   InT       element type of the source vector
//   OutT      element type of the destination vector
//   ScaOp     functor type, invoked as scalar_op(OutT&, const InT&)
//
// The struct is a plain aggregate so it can be built as
// DefaultVecOp<...>{scalar_op}.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
  // The element-wise operation applied to every lane.
  ScaOp scalar_op;

  // Calls scalar_op once per lane, in ascending lane order, reading from
  // src.val[lane] and writing through dst.val[lane]. The call operator is
  // const, so ScaOp must be invocable on a const object.
  __device__ __forceinline__ void operator()(
      vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
    for (int lane = 0; lane < VEC_SIZE; ++lane) {
      auto& result_slot = dst.val[lane];
      const auto& input_elem = src.val[lane];
      scalar_op(result_slot, input_elem);
    }
  }
};
|
|
// Applies an element-wise operation to `len` elements of `in`, writing the
// results to `out`, and uses VEC_SIZE-wide vector accesses where the pointers
// allow it.
//
// The range is processed in three phases. In each phase the caller handles
// indices i = tid, tid + stride, tid + 2 * stride, ...:
//   1. scalar prefix: the elements in front of the first address of `in`
//      that is a multiple of VEC_SIZE * sizeof(InT) bytes;
//   2. vector body:   whole vec_n_t<InT, VEC_SIZE> loads passed through
//      vec_op and stored as vec_n_t<OutT, VEC_SIZE>;
//   3. scalar tail:   whatever is left after the last whole vector.
//
// Parameters:
//   in         pointer to the first input element
//   out        pointer to the first output element
//   len        number of elements to process
//   tid        first index handled by the caller in each phase
//   stride     step between indices handled by the caller
//              (presumably thread index / thread count - confirm at callers)
//   vec_op     invoked as vec_op(vec_n_t<OutT, VEC_SIZE>&,
//                                const vec_n_t<InT, VEC_SIZE>&)
//   scalar_op  invoked as scalar_op(OutT&, const InT&)
//
// Which elements go through vec_op and which through scalar_op depends on
// the runtime alignment of `in` and `out`, so both callables need to produce
// the same per-element result.
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
          typename ScaOp>
__device__ inline void vectorize_with_alignment(
    const InT* in, OutT* out, int len, int tid, int stride,
    VecOp&& vec_op,
    ScaOp&& scalar_op) {
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");
  using vin_t = vec_n_t<InT, VEC_SIZE>;
  using vout_t = vec_n_t<OutT, VEC_SIZE>;

  // Byte width of one input vector. The mask arithmetic below is only exact
  // when this is a power of two; the alignment check further down covers the
  // case where it is not.
  constexpr int WIDTH = VEC_SIZE * sizeof(InT);
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);

  // Number of leading elements that sit before the next WIDTH-byte boundary
  // of `in`. The second mask maps "already aligned" (WIDTH bytes) to 0.
  int misalignment_offset = addr & (WIDTH - 1);
  int alignment_bytes = WIDTH - misalignment_offset;
  int prefix_elems = alignment_bytes & (WIDTH - 1);
  prefix_elems /= sizeof(InT);
  prefix_elems = min(prefix_elems, len);

  // 1. Scalar prefix.
  for (int i = tid; i < prefix_elems; i += stride) {
    scalar_op(out[i], in[i]);
  }

  in += prefix_elems;
  out += prefix_elems;
  len -= prefix_elems;

  // The prefix was derived from `in` alone, so `out` is not guaranteed to be
  // suitably aligned for whole-vector stores (and `in` is not either when
  // WIDTH is not a power of two). Vector accesses are only issued when both
  // pointers satisfy the alignment of their vector types; otherwise num_vec
  // is 0 and the tail loop below processes every remaining element.
  const bool in_vec_aligned =
      (reinterpret_cast<uintptr_t>(in) & (alignof(vin_t) - 1)) == 0;
  const bool out_vec_aligned =
      (reinterpret_cast<uintptr_t>(out) & (alignof(vout_t) - 1)) == 0;
  const bool can_vectorize = in_vec_aligned && out_vec_aligned;

  int num_vec = can_vectorize ? len / VEC_SIZE : 0;
  auto* v_in = reinterpret_cast<const vin_t*>(in);
  auto* v_out = reinterpret_cast<vout_t*>(out);

  // 2. Vector body: one whole-vector load, one vec_op call, one whole-vector
  //    store per iteration.
  for (int i = tid; i < num_vec; i += stride) {
    vout_t tmp;
    vec_op(tmp, v_in[i]);
    v_out[i] = tmp;
  }

  // 3. Scalar tail: everything after the last whole vector.
  int tail_start = num_vec * VEC_SIZE;
  for (int i = tid + tail_start; i < len; i += stride) {
    scalar_op(out[i], in[i]);
  }
}
|
|
// Convenience overload for callers that only have an element-wise functor.
//
// It wraps a copy of `scalar_op` in a DefaultVecOp, which applies it to each
// lane of a vector, and delegates to the overload that takes separate vector
// and scalar callables. All other arguments are passed through unchanged.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
                                                         OutT* out, int len,
                                                         int tid, int stride,
                                                         ScaOp&& scalar_op) {
  // Strip references/cv-qualifiers so the wrapper stores the functor by value.
  using ScalarFn = std::decay_t<ScaOp>;
  using LanewiseOp = DefaultVecOp<VEC_SIZE, InT, OutT, ScalarFn>;

  vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride,
                                     LanewiseOp{scalar_op},
                                     std::forward<ScaOp>(scalar_op));
}
|
|
| } |
|
|