#include <cuda_runtime.h>

#include <iostream>
#include <algorithm>
#include <random>

// Forward ReLU that additionally packs the sign of every input into a bitmask:
// one bit per element, 32 elements per unsigned int. __ballot_sync gathers the
// predicate of every lane in the warp, so lane k's sign lands in bit k.
__global__ void relu(float* output, const float* input, unsigned int* sign32, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    bool sign = false;
    if (i < n)
    {
        auto inp = input[i];
        sign = inp > 0;
        output[i] = sign ? inp : 0;
    }

    // all lanes must participate in the ballot; out-of-range lanes contribute 0
    unsigned predicate = __ballot_sync(0xFFFFFFFF, sign);
    if (threadIdx.x % 32 == 0)
        sign32[i / 32] = predicate; // the launch in main() uses exactly n threads, so no out-of-range write
}
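
// Packing layout (worked example, not part of the original gist): input
// element j is stored in word j / 32 at bit j % 32, in ballot order.
// For instance, element 37 lives in sign32[1], bit 5, so its gradient is
// (sign32[1] >> 5) & 1.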
// Baseline ReLU gradient: rereads the full float input to recover the sign.
__global__ void relu_grad(float* dz, const float* input, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        dz[i] = input[i] > 0;
}
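
// Memory-traffic arithmetic (not in the original gist): the baseline reads
// 4 bytes of input per element, while the packed mask costs 1 bit per
// element, a 32x reduction in gradient-input reads (both variants still
// write 4 bytes of dz per element).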
// ReLU gradient computed from the packed sign mask alone. N is the number of
// elements handled per thread: 1, or 4 via a vectorized float4 store. __brev
// reverses the ballot word so that ballot bit b can be selected with the mask
// 0x80000000 >> b.
template <int N>
__global__ void relu_grad_fast(float* dz, const unsigned int* sign32, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // each word of sign32 covers 32 scalars, i.e. 32 / N work-items
    unsigned int predicate = __brev(__ldg(&sign32[i / (32 / N)]));
    if (i < n)
    {
        // first bit (after __brev) belonging to this thread's elements;
        // correct as long as blockDim.x is a multiple of 32 / N
        const int laneid_byN = threadIdx.x % (32 / N) * N;
        if (N == 4)
        {
            float4 dy;
            dy.x = (predicate & (0x80000000 >> laneid_byN)) != 0;
            dy.y = (predicate & (0x80000000 >> (laneid_byN + 1))) != 0;
            dy.z = (predicate & (0x80000000 >> (laneid_byN + 2))) != 0;
            dy.w = (predicate & (0x80000000 >> (laneid_byN + 3))) != 0;
            reinterpret_cast<float4*>(dz)[i] = dy;
        }
        else if (N == 1)
        {
            dz[i] = (predicate & (0x80000000 >> laneid_byN)) != 0;
        }
        else
        {
            static_assert(N == 4 || N == 1, "only N = 1 and N = 4 are implemented");
        }
    }
}
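
// Unpacking sketch for N = 4 (worked example, not part of the original gist):
// thread i handles scalars 4i..4i+3 and reads sign32[i / 8]. With
// blockDim.x = 256, threadIdx.x % 8 == i % 8, so laneid_byN = (i % 8) * 4.
// Scalar 4i + j has ballot bit 4 * (i % 8) + j, which after __brev sits under
// the mask 0x80000000 >> (laneid_byN + j), exactly as indexed above.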
int main()
{
    constexpr int N = 1024 * 1024 * 16;

    float *input;
    unsigned int *sign32;
    float *output;
    cudaMalloc(&input, N * sizeof(float));
    cudaMalloc(&sign32, N * sizeof(unsigned int) / 32); // does not handle N that is not a multiple of 32
    cudaMalloc(&output, N * sizeof(float));

    float *input_h = new float[N];
    float *output_h = new float[N];
    float *output_ref = new float[N];

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dis(-50, 50);
    for (int i = 0; i < N; i++)
    {
        float x = dis(gen);
        input_h[i] = x;
        output_ref[i] = x > 0; // reference ReLU gradient
    }

    cudaMemcpy(input, input_h, N * sizeof(float), cudaMemcpyHostToDevice);
    // forward pass populates the packed sign mask
    relu<<<1024 * 16, 1024>>>(output, input, sign32, N);

    // baseline gradient
    cudaMemset(output, 0, N * sizeof(float));
    relu_grad<<<1024 * 16, 1024>>>(output, input, N);
    cudaDeviceSynchronize();
    cudaMemcpy(output_h, output, N * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "[vec1] relu_grad: " << std::equal(output_h, output_h + N, output_ref) << '\n';
    {
        constexpr int BLOCK_SIZE = 1024;
        cudaMemset(output, 0, N * sizeof(float));
        relu_grad_fast<1><<<1024 * 16, BLOCK_SIZE>>>(output, sign32, N);
        cudaDeviceSynchronize();
        cudaMemcpy(output_h, output, N * sizeof(float), cudaMemcpyDeviceToHost);
        std::cout << "[vec1] relu_grad_fast: " << std::equal(output_h, output_h + N, output_ref) << '\n';
    }
    {
        constexpr int BLOCK_SIZE = 256;
        cudaMemset(output, 0, N * sizeof(float));
        // each thread writes a float4, so the element count is N / 4
        relu_grad_fast<4><<<1024 * 16, BLOCK_SIZE>>>(output, sign32, N / 4);
        cudaDeviceSynchronize();
        cudaMemcpy(output_h, output, N * sizeof(float), cudaMemcpyDeviceToHost);
        std::cout << "[vec4] relu_grad_fast: " << std::equal(output_h, output_h + N, output_ref) << '\n';
    }
    delete[] input_h;
    delete[] output_h;
    delete[] output_ref;
    cudaFree(input);
    cudaFree(sign32);
    cudaFree(output);
    return 0;
}
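
For completeness, a minimal sketch of how the packed mask could be verified on the host, assuming the buffer names from main() above; this check is not part of the original gist and would slot in before the frees at the end of main(). Each word of sign32 holds 32 consecutive signs in ballot order (element j in bit j % 32 of word j / 32), so no __brev is needed on the CPU side.

    unsigned int *sign32_h = new unsigned int[N / 32];
    cudaMemcpy(sign32_h, sign32, N / 32 * sizeof(unsigned int), cudaMemcpyDeviceToHost);
    bool ok = true;
    for (int i = 0; i < N; i++)
    {
        bool packed = (sign32_h[i / 32] >> (i % 32)) & 1;
        ok = ok && (packed == (input_h[i] > 0));
    }
    std::cout << "sign32 decode check: " << ok << '\n';
    delete[] sign32_h;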