Dhruv Bansal

Reputation: 41

Using Compute Shader in OpenGL ES with kotlin

I want to test out how shaders and compute shaders work.

The compute shader should just color a pixel white and return it, and then the fragment shader should use that color to paint the bottom of the screen.

The fragment shader works fine, but when I tried to add a compute shader, it just does not work.

Take a look at this implementation of the renderer.

These are the shaders.

Compute shader

#version 310 es
precision highp float;
precision highp image2D;

layout (local_size_x = 1, local_size_y = 1) in; // run only for one pixel
layout (rgba32f, binding = 0) uniform writeonly image2D outputImage;

void main() {
    ivec2 pixelCoords = ivec2(gl_GlobalInvocationID.xy);
    imageStore(outputImage, pixelCoords, vec4(1.0)); // just color it white and return
}
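With `local_size_x = 1, local_size_y = 1`, each work group processes exactly one pixel, so the renderer's `glDispatchCompute(1, 1, 1)` covers the whole 1x1 texture. For a larger image the usual approach is to round the work group count up on each axis and let out-of-range invocations return early in the shader. A minimal sketch of the host-side arithmetic, assuming a hypothetical helper named `workGroupCount`:

```kotlin
// Hypothetical helper: number of work groups needed along one axis so that
// groups * localSize >= size (ceiling division for positive integers).
fun workGroupCount(size: Int, localSize: Int): Int {
    require(size > 0 && localSize > 0) { "size and localSize must be positive" }
    return (size + localSize - 1) / localSize
}

// Example: a 1000x600 image with layout(local_size_x = 16, local_size_y = 16)
// would be dispatched as glDispatchCompute(63, 38, 1); invocations whose
// gl_GlobalInvocationID lies outside the image should return without writing.
```

Rounding up instead of assuming the size is divisible by the local size means the last, partially filled group still gets dispatched.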

Fragment shader

#ifdef GL_FRAGMENT_PRECISION_HIGH
precision highp float;
#else
precision mediump float;
#endif

// 'varying' values are passed from the vertex shader to the fragment shader
varying vec2 vTexCoord;

uniform vec3 Accelerometer;// I removed the part of code that puts this data for simplicity
uniform sampler2D ColorSet;// single pixel texture returned from compute shader

void main() {
    if (gl_FragCoord.y > 250.0) {
        float x = Accelerometer[0] / 20.0;
        float y = Accelerometer[1] / 20.0;
        float z = Accelerometer[2] / 20.0;
        gl_FragColor = vec4(x + 0.5, y + 0.5, z + 0.5, 1.0);
    } else {
        gl_FragColor = texture2D(ColorSet, vec2(1.0));
    }
}

I think the shaders are correct, but please tell me if I am wrong; I suspect the problem is in the renderer implementation.

Renderer

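import android.opengl.GLES31
import android.opengl.GLSurfaceView
import java.nio.ByteBuffer
import java.nio.ByteOrder
import javax.microedition.khronos.egl.EGLConfig
import javax.microedition.khronos.opengles.GL10
import timber.log.Timber

class MyGLRenderer : GLSurfaceView.Renderer {
    private var computeProgram = 0
    private var renderProgram = 0
    private var vertexShader = 0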

    private val uniformLocations: MutableList<MutableMap<String, Int?>> = mutableListOf()
    private val uniformValues: MutableMap<String, Any> = mutableMapOf()

    private var textures = IntArray(1)// only one texture
    private var resolution = floatArrayOf(0f,0f)

    override fun onSurfaceCreated(unused: GL10?, config: EGLConfig?) {

        createTextures()
        vertexShader = createVertexShader()
        computeProgram = createComputeProgram()
        renderProgram = createRenderProgram()
    }

    override fun onDrawFrame(gl: GL10?) {
        GLES31.glClear(GLES31.GL_COLOR_BUFFER_BIT or GLES31.GL_DEPTH_BUFFER_BIT)

        dispatchComputeShader()
        readAndLogPixelFromTexture(0, Pair(0,0))
        drawTexture()
    }
}
    private fun createTextures() {
        GLES31.glGenTextures(1, textures, 0)
        GLES31.glBindTexture(GLES31.GL_TEXTURE_2D, textures[0])

        // Set texture parameters
        GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_MIN_FILTER, GLES31.GL_LINEAR)
        GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_MAG_FILTER, GLES31.GL_LINEAR)
        GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_WRAP_S, GLES31.GL_CLAMP_TO_EDGE)
        GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_WRAP_T, GLES31.GL_CLAMP_TO_EDGE)

        // Allocate storage for the texture with the correct format and type
        GLES31.glTexImage2D(
            GLES31.GL_TEXTURE_2D,
            0,
            GLES31.GL_RGBA8,         // Use RGBA32F for floating-point support
            1,                       // Texture width
            1,                       // Texture height
            0,
            GLES31.GL_RGBA,
            GLES31.GL_FLOAT,         // Ensure FLOAT format for the shader compatibility
            null                     // No initial data; just allocate storage
        )
        // Unbind the texture
        GLES31.glBindTexture(GLES31.GL_TEXTURE_2D, 0)
    }

    private fun createVertexShader(): Int {
        return compileShader(GLES31.GL_VERTEX_SHADER,vertexCode)
    }

    private fun createComputeProgram(): Int {
        return GLES31.glCreateShader(GLES31.GL_COMPUTE_SHADER).also { shader ->
            GLES31.glShaderSource(shader,computeCode)
            GLES31.glCompileShader(shader)

            val compileStatus = IntArray(1)
            GLES31.glGetShaderiv(shader, GLES31.GL_COMPILE_STATUS, compileStatus, 0)

            if (compileStatus[0] == 0) {
                val infoLog = GLES31.glGetShaderInfoLog(shader)
                GLES31.glDeleteShader(shader)
                throw RuntimeException("Compute Shader compilation failed: $infoLog")
            }
        }
    }
    private fun createRenderProgram(): Int {

        val fragmentShader = compileShader(GLES31.GL_FRAGMENT_SHADER,rendererCode)
        return GLES31.glCreateProgram().also {
            GLES31.glAttachShader(it, vertexShader)
            GLES31.glAttachShader(it, fragmentShader)
            GLES31.glLinkProgram(it)

            // Check for linking errors
            val linkStatus = IntArray(1)
            GLES31.glGetProgramiv(it, GLES31.GL_LINK_STATUS, linkStatus, 0)
            if (linkStatus[0] == 0) {
                GLES31.glDeleteProgram(it)
                Timber.tag("renderer").d("not linked successfully")
                throw RuntimeException("Could not link program: ${GLES31.glGetProgramInfoLog(it)}")
            }
            else{
                Timber.tag("renderer").d("linked successfully")
            }
        }
    }
    private fun compileShader(type: Int, shaderCode: String): Int {
        return GLES31.glCreateShader(type).also { shader ->
            GLES31.glShaderSource(shader, shaderCode)
            GLES31.glCompileShader(shader)

            // Check for compilation errors
            val compiled = IntArray(1)
            GLES31.glGetShaderiv(shader, GLES31.GL_COMPILE_STATUS, compiled, 0)
            if (compiled[0] == 0) {
                Timber.e("Shader compilation failed: %s", GLES31.glGetShaderInfoLog(shader))
                GLES31.glDeleteShader(shader)
            }
        }
    }
    fun dispatchComputeShader() {
        GLES31.glUseProgram(computeProgram)
        GLES31.glBindImageTexture(0, textures[0], 0, false, 0, GLES31.GL_WRITE_ONLY, GLES31.GL_RGBA32F)

        // One work group of 1x1 invocations covers the whole 1x1 texture (for testing)
        GLES31.glDispatchCompute(1, 1, 1)
        GLES31.glMemoryBarrier(GLES31.GL_SHADER_IMAGE_ACCESS_BARRIER_BIT)
    }
    fun readAndLogPixelFromTexture(texture: Int, pixel: Pair<Int, Int>) {
        val pixelData = FloatArray(4)
        GLES31.glBindTexture(GLES31.GL_TEXTURE_2D, textures[texture])
        GLES31.glReadPixels(pixel.first, pixel.second, 1, 1, GLES31.GL_RGBA, GLES31.GL_FLOAT, java.nio.FloatBuffer.wrap(pixelData))

        // Log the color value to verify it
        val r = pixelData[0]
        val g = pixelData[1]
        val b = pixelData[2]
        val a = pixelData[3]
        Timber.d("SET-COLOR: R: $r, G: $g, B: $b, A: $a")
    }
    private fun drawTexture() {
        GLES31.glUseProgram(renderProgram)

        val positionHandle = GLES31.glGetAttribLocation(renderProgram, "aPosition")
        GLES31.glEnableVertexAttribArray(positionHandle)

        val textureCoordHandle = GLES31.glGetAttribLocation(renderProgram, "aTexCoord")
        GLES31.glEnableVertexAttribArray(textureCoordHandle)

        val textureHandle = GLES31.glGetUniformLocation(renderProgram, "ColorSet")
        GLES31.glActiveTexture(GLES31.GL_TEXTURE0)
        GLES31.glBindTexture(GLES31.GL_TEXTURE_2D, textures[0])
        GLES31.glUniform1i(textureHandle, 0)

        // Example vertices and texture coordinates
        val vertices = floatArrayOf(
            -1.0f, 1.0f,  // top left
            -1.0f, -1.0f, // bottom left
            1.0f, -1.0f,  // bottom right
            1.0f, 1.0f    // top right
        )

        val textureCoords = floatArrayOf(
            0.0f, 0.0f,   // top left
            0.0f, 1.0f,   // bottom left
            1.0f, 1.0f,   // bottom right
            1.0f, 0.0f    // top right
        )

        val vertexBuffer = ByteBuffer.allocateDirect(vertices.size * 4).run {
            order(ByteOrder.nativeOrder())
            asFloatBuffer().apply {
                put(vertices)
                position(0)
            }
        }

        val textureCoordBuffer = ByteBuffer.allocateDirect(textureCoords.size * 4).run {
            order(ByteOrder.nativeOrder())
            asFloatBuffer().apply {
                put(textureCoords)
                position(0)
            }
        }

        GLES31.glVertexAttribPointer(
            positionHandle, 2,
            GLES31.GL_FLOAT, false,
            0, vertexBuffer
        )

        GLES31.glVertexAttribPointer(
            textureCoordHandle, 2,
            GLES31.GL_FLOAT, false,
            0, textureCoordBuffer
        )

        GLES31.glDrawArrays(GLES31.GL_TRIANGLE_FAN, 0, 4)

        GLES31.glDisableVertexAttribArray(positionHandle)
        GLES31.glDisableVertexAttribArray(textureCoordHandle)
    }

The bottom of the screen should be white with this setup, but it stays black.

I tried logging the color of the texture's only pixel, and it also turned out to be RGBA(0, 0, 0, 0).

Can anyone help me out?

Upvotes: 1

Views: 155

Answers (1)

Max Lebedev

Reputation: 47

First of all, instead of GLES31.glTexImage2D you must create an immutable texture (glTexStorage2D) so that a compute shader can write to it through glBindImageTexture:

    GLES31.glTexStorage2D(
        GLES31.GL_TEXTURE_2D,
        1,                       // 1 level
        GLES31.GL_RGBA32F,
        1,                       // Texture width
        1,                       // Texture height
    )
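Using glTexStorage2D also removes the invalid `GL_RGBA8` / `GL_FLOAT` pairing from the question's glTexImage2D call (in OpenGL ES 3.x that combination generates `GL_INVALID_OPERATION`, so no storage is allocated). Switching to `GL_RGBA32F` has a side effect on the sampling side, though: 32-bit float textures are only filterable when the device exposes `OES_texture_float_linear`. With the `GL_LINEAR` filters set in `createTextures`, an unfilterable texture is incomplete and sampling it returns (0, 0, 0, 1), so the bottom of the screen would still be black. A minimal sketch that picks the filter from the extension string; the helper names are hypothetical, and the constants are assumed to mirror `GLES31.GL_NEAREST` / `GLES31.GL_LINEAR`:

```kotlin
// Same numeric values as GLES31.GL_NEAREST and GLES31.GL_LINEAR.
const val GL_NEAREST = 0x2600
const val GL_LINEAR = 0x2601

// Extension strings are space-separated names such as "GL_OES_texture_float_linear".
// Compare whole tokens: a plain substring test would wrongly report
// "GL_OES_texture_float" as present when only "GL_OES_texture_float_linear" is listed.
fun hasExtension(extensions: String, name: String): Boolean =
    extensions.split(' ').any { it == name }

// Value to pass for GL_TEXTURE_MIN_FILTER and GL_TEXTURE_MAG_FILTER on an RGBA32F texture.
fun floatTextureFilter(extensions: String): Int =
    if (hasExtension(extensions, "GL_OES_texture_float_linear")) GL_LINEAR else GL_NEAREST
```

On Android, `extensions` would come from `GLES31.glGetString(GLES31.GL_EXTENSIONS)` called on the GL thread. For a 1x1 texture with `GL_CLAMP_TO_EDGE`, `GL_NEAREST` returns the same color as linear filtering would, so using it unconditionally is also reasonable here.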

Secondly, as explained in this answer https://stackoverflow.com/a/53993894/7167920, OpenGL ES does not let you read texture image data back directly (there is no glGetTexImage), so you have to attach the texture to a framebuffer object and read it with glReadPixels.
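`glReadPixels` always reads from the currently bound read framebuffer, not from whatever texture is bound with `glBindTexture`, so the question's `readAndLogPixelFromTexture` could never see the compute result. Note also that a `GL_RGBA32F` attachment is only color-renderable when the device supports `EXT_color_buffer_float`; without it the framebuffer is reported as incomplete. One alternative (not part of the answer) is to declare the image as `rgba8` in the compute shader, allocate the texture with `GL_RGBA8`, bind it with `GL_RGBA8` in `glBindImageTexture`, and read it back with `GL_RGBA` / `GL_UNSIGNED_BYTE`. Kotlin's `Byte` is signed, so the bytes need masking before they are normalized. A small sketch with a hypothetical helper `decodeRgba8`:

```kotlin
// Converts RGBA8 bytes (as filled by glReadPixels with GL_RGBA / GL_UNSIGNED_BYTE)
// into normalized floats in [0, 1]. Kotlin's Byte is signed (-128..127), so each
// value is masked with 0xFF first; otherwise 255 would read back as -1.
fun decodeRgba8(bytes: ByteArray): FloatArray =
    FloatArray(bytes.size) { i -> (bytes[i].toInt() and 0xFF) / 255f }
```

On Android the bytes would typically come from a direct `ByteBuffer` passed to `glReadPixels(x, y, 1, 1, GLES31.GL_RGBA, GLES31.GL_UNSIGNED_BYTE, buffer)` and copied out with `buffer.get(array)`. A white pixel written by the compute shader would then decode to 1.0 in every channel.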

Create the framebuffer once during initialization (for example in onSurfaceCreated):

frameBuffers = IntArray(1)
GLES31.glGenFramebuffers(1, frameBuffers, 0)

And here is the working code of readAndLogPixelFromTexture():

// checkGLError { ... } is a helper that is not shown in this answer
checkGLError { GLES31.glBindFramebuffer(GLES31.GL_FRAMEBUFFER, frameBuffers[0]) }
checkGLError {
    GLES31.glFramebufferTexture2D(
        GLES31.GL_FRAMEBUFFER,
        GLES31.GL_COLOR_ATTACHMENT0,
        GLES31.GL_TEXTURE_2D,
        textures[texture],
        0
    )
}
if (GLES31.glCheckFramebufferStatus(GLES31.GL_FRAMEBUFFER) != GLES31.GL_FRAMEBUFFER_COMPLETE) {
    throw RuntimeException("Framebuffer is not complete")
}
val pixelData = FloatArray(4)
GLES31.glReadPixels(pixel.first, pixel.second, 1, 1, GLES31.GL_RGBA, GLES31.GL_FLOAT, FloatBuffer.wrap(pixelData))
checkGLError { GLES31.glBindFramebuffer(GLES31.GL_FRAMEBUFFER, 0) }

// Log the color value to verify it
val r = pixelData[0]
val g = pixelData[1]
val b = pixelData[2]
val a = pixelData[3]
Log.d(TAG, "Compute shader result: R: $r, G: $g, B: $b, A: $a")
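One more problem in the question's code is not covered above: `createComputeProgram` returns the compiled *shader* object and never creates a *program*. `glUseProgram` only accepts program objects, so `glUseProgram(computeProgram)` fails with `GL_INVALID_OPERATION`, and `glDispatchCompute` then has no active compute program and does nothing. The compute shader has to be attached to a program and linked, just like the vertex/fragment pair in `createRenderProgram`. Below is a sketch of that sequence; to keep it runnable without a GL context it goes through a hypothetical `Gl` interface whose method names mirror the GLES31 calls (on Android each method would simply forward to `GLES31`), and `compileStatus` / `linkStatus` are hypothetical wrappers around `glGetShaderiv` / `glGetProgramiv`:

```kotlin
// Hypothetical thin wrapper over GLES31. Method names mirror the real calls;
// compileStatus/linkStatus stand in for glGetShaderiv(GL_COMPILE_STATUS) and
// glGetProgramiv(GL_LINK_STATUS).
interface Gl {
    fun glCreateShader(type: Int): Int
    fun glShaderSource(shader: Int, source: String)
    fun glCompileShader(shader: Int)
    fun compileStatus(shader: Int): Boolean
    fun glGetShaderInfoLog(shader: Int): String
    fun glDeleteShader(shader: Int)
    fun glCreateProgram(): Int
    fun glAttachShader(program: Int, shader: Int)
    fun glLinkProgram(program: Int)
    fun linkStatus(program: Int): Boolean
    fun glGetProgramInfoLog(program: Int): String
    fun glDeleteProgram(program: Int)
}

const val GL_COMPUTE_SHADER = 0x91B9 // same value as GLES31.GL_COMPUTE_SHADER

// Compiles the compute shader, links it into a program and returns the program
// name, which is what glUseProgram expects.
fun createComputeProgram(gl: Gl, source: String): Int {
    val shader = gl.glCreateShader(GL_COMPUTE_SHADER)
    gl.glShaderSource(shader, source)
    gl.glCompileShader(shader)
    if (!gl.compileStatus(shader)) {
        val log = gl.glGetShaderInfoLog(shader)
        gl.glDeleteShader(shader)
        throw RuntimeException("Compute shader compilation failed: $log")
    }
    val program = gl.glCreateProgram()
    gl.glAttachShader(program, shader)
    gl.glLinkProgram(program)
    // Safe after attaching: an attached shader is only flagged for deletion.
    gl.glDeleteShader(shader)
    if (!gl.linkStatus(program)) {
        val log = gl.glGetProgramInfoLog(program)
        gl.glDeleteProgram(program)
        throw RuntimeException("Compute program link failed: $log")
    }
    return program
}

// Recording fake used to exercise the sequence without a GL context.
class FakeGl(private val linkOk: Boolean = true) : Gl {
    val calls = mutableListOf<String>()
    override fun glCreateShader(type: Int): Int { calls += "createShader"; return 1 }
    override fun glShaderSource(shader: Int, source: String) { calls += "shaderSource" }
    override fun glCompileShader(shader: Int) { calls += "compileShader" }
    override fun compileStatus(shader: Int) = true
    override fun glGetShaderInfoLog(shader: Int) = ""
    override fun glDeleteShader(shader: Int) { calls += "deleteShader" }
    override fun glCreateProgram(): Int { calls += "createProgram"; return 2 }
    override fun glAttachShader(program: Int, shader: Int) { calls += "attachShader" }
    override fun glLinkProgram(program: Int) { calls += "linkProgram" }
    override fun linkStatus(program: Int) = linkOk
    override fun glGetProgramInfoLog(program: Int) = "link error"
    override fun glDeleteProgram(program: Int) { calls += "deleteProgram" }
}
```

In the renderer, the returned program name would be stored in `computeProgram` during `onSurfaceCreated` and passed to `glUseProgram` in `dispatchComputeShader`.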

Upvotes: 0
