#define CL_TARGET_OPENCL_VERSION 300
#include <CL/opencl.h>
#include <err.h>
#include <stdio.h>
#include <stdlib.h>

// OpenCL C source for the "inc" kernel: every work-item whose global id i
// satisfies i < size adds 1.0 to device_a[i]; work-items with i >= size do
// nothing, so the host may launch more work-items than elements.
// The index is kept as size_t (the type get_global_id returns) so it is
// neither truncated nor compared signed-against-unsigned with size.
// NOTE: the kernel uses double, so the device must support fp64.
const char *incSource =
    "\n"
    "__kernel void inc(                      \n"
    "   __global double* device_a,           \n"
    "   const unsigned int size) {           \n"
    "   size_t i = get_global_id(0);         \n"
    "   if(i < size)                         \n"
    "       device_a[i]= device_a[i] + 1.0;  \n"
    "}                                       \n";

int main(int arg, char *argv[]) {
    const unsigned int size = 1 << 16;

    cl_int errCode;

    // Obtain the first available platform.
    cl_platform_id platformID = NULL;
    cl_uint numPlatforms;
    errCode = clGetPlatformIDs(1, &platformID, &numPlatforms);
    if (errCode != CL_SUCCESS) {
        errx(1, "clGetPlatformIDs() failed");
    }

    // Get size requirement for platform name
    size_t infoSize;
    clGetPlatformInfo(platformID, CL_PLATFORM_NAME, 0, NULL, &infoSize);

    // Get platform name
    char *platformName = (char *) malloc(infoSize);
    clGetPlatformInfo(platformID, CL_PLATFORM_NAME, infoSize, platformName, NULL);

    printf("OpenCL Platform name: %s\n", platformName);

    // Obtain the first available device on the platform
    cl_device_id deviceID = NULL;
    cl_uint numDevices;
    errCode = clGetDeviceIDs(platformID, CL_DEVICE_TYPE_DEFAULT, 1,
                             &deviceID, &numDevices);
    if (errCode != CL_SUCCESS) {
        errx(1, "clGetDeviceIDs() failed");
    }

    // Get size requirement for device name
    clGetDeviceInfo(deviceID, CL_DEVICE_NAME, 0, NULL, &infoSize);

    // Get device name
    char *deviceName = (char *) malloc(infoSize);
    clGetDeviceInfo(deviceID, CL_DEVICE_NAME, infoSize, deviceName, NULL);

    printf("OpenCL Device name: %s\n", deviceName);

    // Create an OpenCL context
    // Contexts are used by the OpenCL runtime for managing objects such as command-queues, memory, program and kernel objects
    cl_context context =
        clCreateContext(NULL, 1, &deviceID, NULL, NULL, &errCode);
    if (errCode != CL_SUCCESS) {
        errx(1, "clCreateContext() failed");
    }

    // Create a command queue
    cl_command_queue commandQueue =
        clCreateCommandQueueWithProperties(context, deviceID, NULL, &errCode);
    if (errCode != CL_SUCCESS) {
        errx(1, "clCreateCommandQueue() failed");
    }

    printf("* Allocate memory on the host\n");
    double *a = (double *) malloc(size * sizeof(double));
    if (a == NULL) {
        errx(1, "malloc a[] failed");
    }

    printf("* Pre-process / initialize data on the host\n");
    printf("  e.g. read data from storage\n");
    for (int i = 0; i < size; i++) {
        a[i] = 1.;
    }

    printf("* Allocate memory on the device\n");
    cl_mem device_a =
        clCreateBuffer(context, CL_MEM_READ_WRITE, size * sizeof(double), NULL, &errCode);
    if (errCode != CL_SUCCESS) {
        errx(1, "clCreateBuffer() failed");
    }

    printf("* Copy data from the host to the device\n");
    errCode = clEnqueueWriteBuffer(commandQueue, device_a, CL_TRUE, 0, size * sizeof(double), a, 0, NULL, NULL);
    if (errCode != CL_SUCCESS) {
        errx(1, "clEnqueueWriteBuffer() failed");
    }

    //
    printf("* Compute on the device\n");
    //

    // Creates a program object for a context, and loads source code specified by text strings into the program object
    cl_program program =
        clCreateProgramWithSource(context, 1, &incSource, NULL, &errCode);
    if (errCode != CL_SUCCESS) {
        errx(1, "clCreateProgramWithSource() failed");
    }

    // Builds (compiles and links) a program executable from the program source
    errCode = clBuildProgram(program, 1, &deviceID, NULL, NULL, NULL);
    if (errCode != CL_SUCCESS) {
        size_t len;
        char buffer[2048];
        clGetProgramBuildInfo(program, deviceID, CL_PROGRAM_BUILD_LOG, sizeof(buffer), buffer, &len);
        errx(1, "clBuildProgram() failed:\n%s", buffer);
    }

    // Creates a kernel object
    cl_kernel kernel =
        clCreateKernel(program, "inc", &errCode);
    if (errCode != CL_SUCCESS) {
        errx(1, "clCreateKernel() failed");
    }

    // Set the argument value for a specific argument of a kernel
    errCode = clSetKernelArg(kernel, 0, sizeof(cl_mem), &device_a);
    if (errCode != CL_SUCCESS) {
        errx(1, "clSetKernelArg() failed");
    }
    errCode = clSetKernelArg(kernel, 1, sizeof(unsigned int), &size);
    if (errCode != CL_SUCCESS) {
        errx(1, "clSetKernelArg() failed");
    }

    // Query the maximum workgroup size
    size_t local;
    errCode = clGetKernelWorkGroupInfo(kernel, deviceID, CL_KERNEL_WORK_GROUP_SIZE, sizeof(local), &local, NULL);
    if (errCode != CL_SUCCESS) {
        errx(1, "clGetKernelWorkGroupInfo() failed");
    }

    // Enqueues a command to execute a kernel on a device
    size_t global = size;
    errCode = clEnqueueNDRangeKernel(commandQueue, kernel, 1, NULL, &global, &local, 0, NULL, NULL);
    if (errCode != CL_SUCCESS) {
        errx(1, "clEnqueueNDRangeKernel() failed");
    }

    // Wait for command completion
    errCode = clFinish(commandQueue);
    if (errCode != CL_SUCCESS) {
        errx(1, "clFinish() failed");
    }

    // Release the kernel object
    errCode = clReleaseKernel(kernel);

    // Release the program object
    errCode = clReleaseProgram(program);

    // Release the device
    errCode = clReleaseDevice(deviceID);

    printf("* Transfer data back from the device to the host\n");
    errCode = clEnqueueReadBuffer(commandQueue, device_a, CL_TRUE, 0, size * sizeof(double), a, 0, NULL, NULL);
    if (errCode != CL_SUCCESS) {
        errx(1, "clEnqueueReadBuffer() failed");
    }

    printf("* Delete data on the device\n");
    errCode = clReleaseMemObject(device_a);
    if (errCode != CL_SUCCESS) {
        errx(1, "clReleaseMemObject() failed");
    }

    // Release a command queue
    errCode = clReleaseCommandQueue(commandQueue);

    // release the context
    errCode = clReleaseContext(context);

    printf("* Post-process data on the host\n");
    printf("  e.g. write data to storage\n");
    for (int i = 0; i < size; i++) {
        if (a[i] != 2.) {
            errx(2, "Computation on GPU failed");
        }
    }

    printf("* Free memory on the host\n");
    free(a);

    return 0;
}