1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
import os

import torch
from torch.utils.cpp_extension import load_inline
# CUDA/C++ translation unit passed to load_inline(cuda_sources=...).
# It defines an element-wise square kernel (result[i] = matrix[i] * matrix[i])
# and the host wrapper `square_matrix` that cpp_source declares and the
# extension exposes to Python.  Note: this squares each element; it is NOT a
# matrix product.
cuda_source = '''
#include <sys/types.h>
#include <unistd.h>

#include <cstdio>
#include <limits>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

// Element-wise square of a row-major (contiguous) height x width float matrix.
// One thread per element; threads that fall outside the matrix do nothing.
__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < height && col < width) {
        // 64-bit flat index so matrices with more than INT_MAX elements do not overflow.
        const long long idx = static_cast<long long>(row) * width + col;
        result[idx] = matrix[idx] * matrix[idx];
    }
}

// Returns a new tensor holding `matrix` squared element-wise.
// Requires a 2-D float32 CUDA tensor. Non-contiguous inputs are first copied
// to a contiguous layout because the kernel assumes row-major flat indexing.
torch::Tensor square_matrix(torch::Tensor matrix) {
    TORCH_CHECK(matrix.is_cuda(), "square_matrix: expected a CUDA tensor");
    TORCH_CHECK(matrix.scalar_type() == torch::kFloat32,
                "square_matrix: expected a float32 tensor, got ", matrix.scalar_type());
    TORCH_CHECK(matrix.dim() == 2,
                "square_matrix: expected a 2-D tensor, got ", matrix.dim(), " dimensions");

    const auto height = matrix.size(0);
    const auto width = matrix.size(1);
    // The kernel takes int sizes.
    TORCH_CHECK(height <= std::numeric_limits<int>::max() && width <= std::numeric_limits<int>::max(),
                "square_matrix: each dimension must fit in a 32-bit int");

    // Debug output kept from the original: prints the host process id
    // (presumably so a debugger can be attached). Flushed so it is not held
    // back in the C stdio buffer behind Python's own output.
    pid_t pid = getpid();
    printf("pid %d " , pid);
    fflush(stdout);

    // Make the tensor's device current so the stream lookup and the launch
    // below target the same GPU the data lives on.
    const c10::cuda::CUDAGuard device_guard(matrix.device());
    const auto input = matrix.contiguous();
    auto result = torch::empty_like(input);

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (input.numel() == 0) {
        return result;
    }

    const dim3 threads_per_block(16, 16);
    const int64_t grid_x = (width + threads_per_block.x - 1) / threads_per_block.x;
    const int64_t grid_y = (height + threads_per_block.y - 1) / threads_per_block.y;
    // gridDim.y is limited to 65535 blocks.
    TORCH_CHECK(grid_y <= 65535, "square_matrix: too many rows for the 2-D launch grid");
    const dim3 number_of_blocks(static_cast<unsigned int>(grid_x), static_cast<unsigned int>(grid_y));

    // Launch on PyTorch's current stream so ordering with surrounding ops holds.
    square_matrix_kernel<<<number_of_blocks, threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(), result.data_ptr<float>(), static_cast<int>(width), static_cast<int>(height));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return result;
}
'''
# C++ declaration of the host wrapper defined in cuda_source; load_inline
# compiles it as the C++ source through which `functions=['square_matrix']`
# is bound for Python.
cpp_source = (
    "torch::Tensor square_matrix("
    "torch::Tensor matrix);"
)
# Workspace for the generated sources and compiled artifacts.  load_inline
# writes its generated files straight into a caller-supplied build_directory
# and does not create it, so make sure it exists before building.
_BUILD_DIR = "./load_inline_cuda"
os.makedirs(_BUILD_DIR, exist_ok=True)

# JIT-compile cpp_source + cuda_source into an extension module that exposes
# `square_matrix` to Python.  nvcc is passed -O2.
square_matrix_extension = load_inline(
    name="square_matrix_extension",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["square_matrix"],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    build_directory=_BUILD_DIR,
)
if __name__ == "__main__":
    # Smoke test, run only when executed as a script (not on import):
    # square a 2x3 CUDA tensor element-wise, print it, and check it against
    # PyTorch's own element-wise product.
    a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device="cuda")
    squared = square_matrix_extension.square_matrix(a)
    print(squared)
    torch.testing.assert_close(squared, a * a)
|