Reputation: 173
How can I calculate the mean and variance of an image with 16 channels using Metal?
I want to calculate the mean and variance of each channel separately.
For example:
kernel void meanandvariance(texture2d_array<float, access::read> in [[texture(0)]],
                            texture2d_array<float, access::write> out [[texture(1)]],
                            ushort3 gid [[thread_position_in_grid]],
                            ushort tid [[thread_index_in_threadgroup]],
                            ushort3 tg_size [[threads_per_threadgroup]]) {
}
Upvotes: 2
Views: 942
Reputation: 31782
There's probably a way to do this by creating a sequence of texture views on the input texture array and output texture array, and encoding an MPSImageStatisticsMeanAndVariance kernel invocation for each slice.
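For example, a per-slice MPS version might look something like this sketch (untested; it assumes a device, a commandQueue, and an rgba32Float array texture sourceTexture, and keeps one small 2x1 result texture per slice rather than views on an output array):

import Metal
import MetalPerformanceShaders

let statsKernel = MPSImageStatisticsMeanAndVariance(device: device)
let commandBuffer = commandQueue.makeCommandBuffer()!
var sliceResults: [MTLTexture] = []

for slice in 0..<sourceTexture.arrayLength {
    // View a single slice of the array texture as an ordinary 2D texture
    let sliceView = sourceTexture.makeTextureView(pixelFormat: sourceTexture.pixelFormat,
                                                  textureType: .type2D,
                                                  levels: 0..<1,
                                                  slices: slice..<(slice + 1))!
    // MPS writes the mean to (0, 0) and the variance to (1, 0) of the destination
    let destDesc = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float,
                                                            width: 2, height: 1, mipmapped: false)
    destDesc.usage = .shaderWrite
    let sliceResult = device.makeTexture(descriptor: destDesc)!
    statsKernel.encode(commandBuffer: commandBuffer,
                       sourceTexture: sliceView,
                       destinationTexture: sliceResult)
    sliceResults.append(sliceResult)
}

commandBuffer.commit()
commandBuffer.waitUntilCompleted()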
But let's take a look at how to do it ourselves. There are many different possible approaches, so I chose one that was simple and used some interesting results from statistics.
Essentially, we'll do the following:
1. Run one thread per row per slice that computes the partial mean and variance of its row and writes them to an intermediate texture array.
2. Run one thread per slice that reduces those per-row statistics into the overall mean and variance of the slice and writes them to the output texture array.
Here are the kernels:
kernel void compute_row_mean_variance_array(texture2d_array<float, access::read> inTexture [[texture(0)]],
                                            texture2d_array<float, access::write> outTexture [[texture(1)]],
                                            uint3 tpig [[thread_position_in_grid]])
{
    uint row = tpig.x;
    uint slice = tpig.y;
    uint width = inTexture.get_width();
    if (row >= inTexture.get_height() || slice >= inTexture.get_array_size()) { return; }
    float4 mean(0.0f);
    float4 var(0.0f);
    for (uint col = 0; col < width; ++col) {
        float4 rgba = inTexture.read(ushort2(col, row), slice);
        // http://datagenetics.com/blog/november22017/index.html
        float weight = 1.0f / (col + 1);
        float4 oldMean = mean;
        mean = mean + (rgba - mean) * weight;
        var = var + (rgba - oldMean) * (rgba - mean);
    }
    var = var / width;
    outTexture.write(mean, ushort2(row, 0), slice);
    outTexture.write(var, ushort2(row, 1), slice);
}
kernel void reduce_mean_variance_array(texture2d_array<float, access::read> inTexture [[texture(0)]],
                                       texture2d_array<float, access::write> outTexture [[texture(1)]],
                                       uint3 tpig [[thread_position_in_grid]])
{
    uint width = inTexture.get_width();
    uint slice = tpig.x;
    // https://arxiv.org/pdf/1007.1012.pdf
    float4 mean(0.0f);
    float4 meanOfVar(0.0f);
    float4 varOfMean(0.0f);
    for (uint col = 0; col < width; ++col) {
        float weight = 1.0f / (col + 1);
        float4 oldMean = mean;
        float4 submean = inTexture.read(ushort2(col, 0), slice);
        mean = mean + (submean - mean) * weight;
        float4 subvar = inTexture.read(ushort2(col, 1), slice);
        meanOfVar = meanOfVar + (subvar - meanOfVar) * weight;
        varOfMean = varOfMean + (submean - oldMean) * (submean - mean);
    }
    float4 var = meanOfVar + varOfMean / width;
    outTexture.write(mean, ushort2(0, 0), slice);
    outTexture.write(var, ushort2(1, 0), slice);
}
In summary, to achieve step 1, we use an "online" (incremental) algorithm to calculate the partial mean/variance of each row in a way that's more numerically stable than just adding all the pixel values and dividing by the width. My reference for writing this kernel was the DataGenetics post linked in the comment above. Each thread in the grid writes its row's statistics to the appropriate column and slice of an intermediate texture array.
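As a plain scalar reference, the same incremental update looks like this in Swift (a sketch for a single channel, computing the population variance, i.e. dividing by n):

func onlineMeanVariance(_ values: [Float]) -> (mean: Float, variance: Float) {
    var mean: Float = 0
    var m2: Float = 0                        // running sum of squared deviations
    for (i, x) in values.enumerated() {
        let weight = 1 / Float(i + 1)
        let oldMean = mean
        mean += (x - mean) * weight          // mean of the first i + 1 samples
        m2 += (x - oldMean) * (x - mean)     // same update as `var` in the kernel
    }
    return (mean, m2 / Float(values.count))  // population variance
}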
To achieve step 2, we need to find a statistically-sound way of computing the overall statistics from the partial results. This is quite simple in the case of finding the mean: the mean of the population is the mean of the means of the subsets (this holds when the sample size of each subset is the same; in the general case, the overall mean is a weighted sum of the subset means). The variance is trickier, but it turns out that the variance of the population is the sum of the mean of the variances of the subsets and the variance of the means of the subsets (the same caveat about equally-sized subsets applies here). This is a convenient fact that we can combine with our incremental approach above to produce the final mean and variance of each slice, which is written to the corresponding slice of the output texture.
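If you want to convince yourself of that identity, here's a quick Swift check on toy data with two equally-sized subsets (again using the population variance):

func stats(_ xs: [Float]) -> (mean: Float, variance: Float) {
    let mean = xs.reduce(0, +) / Float(xs.count)
    let variance = xs.map { ($0 - mean) * ($0 - mean) }.reduce(0, +) / Float(xs.count)
    return (mean, variance)
}

let rows: [[Float]] = [[1, 2, 3, 4], [2, 4, 6, 8]]
let rowStats = rows.map(stats)

let meanOfMeans = rowStats.map { $0.mean }.reduce(0, +) / Float(rowStats.count)
let meanOfVars = rowStats.map { $0.variance }.reduce(0, +) / Float(rowStats.count)
let varOfMeans = stats(rowStats.map { $0.mean }).variance

let overall = stats(rows.flatMap { $0 })
print(meanOfMeans, overall.mean)                 // 3.75, 3.75
print(meanOfVars + varOfMeans, overall.variance) // 4.6875, 4.6875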
For completeness, here's the Swift code I used to drive these kernels:
let library = device.makeDefaultLibrary()!
let meanVarKernelFunction = library.makeFunction(name: "compute_row_mean_variance_array")!
let meanVarComputePipelineState = try! device.makeComputePipelineState(function: meanVarKernelFunction)
let reduceKernelFunction = library.makeFunction(name: "reduce_mean_variance_array")!
let reduceComputePipelineState = try! device.makeComputePipelineState(function: reduceKernelFunction)

let width = sourceTexture.width
let height = sourceTexture.height
let arrayLength = sourceTexture.arrayLength

// Intermediate texture: one column per source row; row 0 holds means, row 1 holds variances
let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float, width: width, height: height, mipmapped: false)
textureDescriptor.textureType = .type2DArray
textureDescriptor.arrayLength = arrayLength
textureDescriptor.width = height
textureDescriptor.height = 2
textureDescriptor.usage = [.shaderRead, .shaderWrite]
let partialResultsTexture = device.makeTexture(descriptor: textureDescriptor)!

// Final texture: per slice, (0, 0) holds the mean and (1, 0) holds the variance
textureDescriptor.width = 2
textureDescriptor.height = 1
textureDescriptor.usage = .shaderWrite
let destTexture = device.makeTexture(descriptor: textureDescriptor)!

let commandBuffer = commandQueue.makeCommandBuffer()!
let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder()!

// Pass 1: one thread per (row, slice) computes that row's mean and variance
computeCommandEncoder.setComputePipelineState(meanVarComputePipelineState)
computeCommandEncoder.setTexture(sourceTexture, index: 0)
computeCommandEncoder.setTexture(partialResultsTexture, index: 1)
let meanVarGridSize = MTLSize(width: sourceTexture.height, height: sourceTexture.arrayLength, depth: 1)
let meanVarThreadgroupSize = MTLSizeMake(meanVarComputePipelineState.threadExecutionWidth, 1, 1)
let meanVarThreadgroupCount = MTLSizeMake((meanVarGridSize.width + meanVarThreadgroupSize.width - 1) / meanVarThreadgroupSize.width,
                                          (meanVarGridSize.height + meanVarThreadgroupSize.height - 1) / meanVarThreadgroupSize.height,
                                          1)
computeCommandEncoder.dispatchThreadgroups(meanVarThreadgroupCount, threadsPerThreadgroup: meanVarThreadgroupSize)

// Pass 2: one thread per slice reduces the per-row statistics to a single mean and variance
computeCommandEncoder.setComputePipelineState(reduceComputePipelineState)
computeCommandEncoder.setTexture(partialResultsTexture, index: 0)
computeCommandEncoder.setTexture(destTexture, index: 1)
let reduceThreadgroupSize = MTLSizeMake(1, 1, 1)
let reduceThreadgroupCount = MTLSizeMake(arrayLength, 1, 1)
computeCommandEncoder.dispatchThreadgroups(reduceThreadgroupCount, threadsPerThreadgroup: reduceThreadgroupSize)
computeCommandEncoder.endEncoding()

// Optional comparison against MPS: `meanVarKernel` is an MPSImageStatisticsMeanAndVariance and
// `sourceTexture2D` is an ordinary 2D texture with the same data, both created elsewhere in my test harness
let destTexture2DDesc = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float, width: 2, height: 1, mipmapped: false)
destTexture2DDesc.usage = .shaderWrite
let destTexture2D = device.makeTexture(descriptor: destTexture2DDesc)!
meanVarKernel.encode(commandBuffer: commandBuffer, sourceTexture: sourceTexture2D, destinationTexture: destTexture2D)

#if os(macOS)
let blitCommandEncoder = commandBuffer.makeBlitCommandEncoder()!
blitCommandEncoder.synchronize(resource: destTexture)
blitCommandEncoder.synchronize(resource: destTexture2D)
blitCommandEncoder.endEncoding()
#endif

commandBuffer.commit()
commandBuffer.waitUntilCompleted()
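To pull the per-slice results back to the CPU after waitUntilCompleted(), something like this sketch works, assuming the destTexture layout above (mean at (0, 0), variance at (1, 0) of each slice):

var pixels = [SIMD4<Float>](repeating: .zero, count: 2)
for slice in 0..<arrayLength {
    pixels.withUnsafeMutableBytes { buffer in
        destTexture.getBytes(buffer.baseAddress!,
                             bytesPerRow: 2 * MemoryLayout<SIMD4<Float>>.stride,
                             bytesPerImage: 0, // region depth is 1
                             from: MTLRegionMake2D(0, 0, 2, 1),
                             mipmapLevel: 0,
                             slice: slice)
    }
    print("slice \(slice): mean = \(pixels[0]), variance = \(pixels[1])")
}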
In my experiments, this program produced the same results as MPSImageStatisticsMeanAndVariance, give or take some differences on the order of 1e-7. It was also 2.5x slower than MPS on my Mac, probably due in part to its failure to exploit latency hiding with more granular parallelism.
Upvotes: 3
Reputation: 173
#include <metal_stdlib>
using namespace metal;

kernel void instance_norm(constant float4* scale [[buffer(0)]],
                          constant float4* shift [[buffer(1)]],
                          texture2d_array<float, access::read> in [[texture(0)]],
                          texture2d_array<float, access::write> out [[texture(1)]],
                          ushort3 gid [[thread_position_in_grid]],
                          ushort tid [[thread_index_in_threadgroup]],
                          ushort3 tg_size [[threads_per_threadgroup]]) {
    ushort width = in.get_width();
    ushort height = in.get_height();
    const ushort thread_count = tg_size.x * tg_size.y;

    threadgroup float4 shared_mem [256];

    float4 sum = 0;
    for (ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for (ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            sum += in.read(ushort2(xIndex, yIndex), gid.z);
        }
    }
    shared_mem[tid] = sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to 32 values
    sum = 0;
    if (tid < 32) {
        for (ushort i = tid + 32; i < thread_count; i += 32) {
            sum += shared_mem[i];
        }
    }
    shared_mem[tid] += sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Calculate mean
    sum = 0;
    if (tid == 0) {
        ushort top = min(ushort(32), thread_count);
        for (ushort i = 0; i < top; i += 1) {
            sum += shared_mem[i];
        }
        shared_mem[0] = sum / (width * height);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float4 mean = shared_mem[0];

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Variance
    sum = 0;
    for (ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for (ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            sum += pow(in.read(ushort2(xIndex, yIndex), gid.z) - mean, 2);
        }
    }
    shared_mem[tid] = sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to 32 values
    sum = 0;
    if (tid < 32) {
        for (ushort i = tid + 32; i < thread_count; i += 32) {
            sum += shared_mem[i];
        }
    }
    shared_mem[tid] += sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Calculate variance
    sum = 0;
    if (tid == 0) {
        ushort top = min(ushort(32), thread_count);
        for (ushort i = 0; i < top; i += 1) {
            sum += shared_mem[i];
        }
        shared_mem[0] = sum / (width * height);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float4 sigma = sqrt(shared_mem[0] + float4(1e-4));
    float4 multiplier = scale[gid.z] / sigma;
    for (ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for (ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            float4 val = in.read(ushort2(xIndex, yIndex), gid.z);
            out.write(clamp((val - mean) * multiplier + shift[gid.z], -10.0, 10.0), ushort2(xIndex, yIndex), gid.z);
        }
    }
}
This is how Bender implements it, but I don't think it's correct. Can anybody verify it?
https://github.com/xmartlabs/Bender/blob/master/Sources/Metal/instanceNorm.metal
Upvotes: 0