//
//  fractal_gpu.metal
//  Gaston
//
//  Created by Richard Kurz on 18/09/14.
//  Copyright (c) 2014 Richard Kurz (Public Domain). No rights reserved.
//

#include <metal_stdlib>

using namespace metal;


// Computes a smoothed escape-time value for one pixel; one thread per pixel.
//
// The pixel's plane coordinates (x, y) are iterated as
//   (x, y) -> (x*x - y*y + a, 2*x*y + b)
// until x*x + y*y exceeds 4 or the iteration limit is reached.
//
// paramter layout (all floats; the first three are truncated to int):
//   [0] maxDepth   iteration limit
//   [1] pixWidth   image width in pixels
//   [2] pixHeight  image height in pixels
//   [3] left       plane coordinate of pixel column 0
//   [4] top        plane coordinate of pixel row 0
//   [5] width      extent of the viewed region along x
//   [6] height     extent of the viewed region along y
//   [7] a, [8] b   constant added in every iteration
//
// output[index] receives
//   depth - log2(log(r2) / 2)   if the orbit escaped (r2 > 4),
//   maxDepth                    otherwise.
// The original wrote the first expression unconditionally, which yields NaN,
// infinity or an arbitrary value above maxDepth for points that never escape.
//
// In real-world code this should work at least per row rather than per pixel;
// that saves the overhead of the call and of the parameter passing.

kernel void fractal(constant float* paramter [[buffer(0)]], device float* output [[buffer(1)]], uint index [[thread_position_in_grid]])
{
  int maxDepth = paramter[0];
  int pixWidth = paramter[1];
  int pixHeight = paramter[2];
  float left = paramter[3];
  float top = paramter[4];
  float width = paramter[5];
  float height = paramter[6];
  float a = paramter[7];
  float b = paramter[8];

  // An empty image would make the modulo/division below divide by zero.
  if (pixWidth <= 0 || pixHeight <= 0) return;

  // Surplus threads (e.g. from a dispatch rounded up to whole threadgroups)
  // must not write past the last pixel.
  if (index >= uint(pixWidth) * uint(pixHeight)) return;

  uint column = index % uint(pixWidth);
  uint row = index / uint(pixWidth);

  float x = left + width * (float(column) / pixWidth);
  float y = top + height * (float(row) / pixHeight);
  float x2, y2, r2;
  int depth = 0;

  do
  {
    y2 = y * y;
    x2 = x * x;
    r2 = x2 + y2;
    if (r2 > 4.f) break;
    y = x * y * 2.f + b;
    x = x2 - y2 + a;
  }
  while (depth++ < maxDepth);

  // The smoothing term needs log(r2) > 0, which is only guaranteed when the
  // loop was left through the escape test. Written as a positive test so that
  // a NaN r2 also takes the "did not escape" branch.
  if (r2 > 4.f)
    output[index] = float(depth) - log(log(r2) * 0.5f) / M_LN2_F;
  else
    output[index] = float(maxDepth);
}


// Maps one escape-time value to an opaque RGBA colour; one thread per value.
//
// paramter[0]   iteration limit used to normalise the red channel
// input[index]  escape-time value for this element
// output[index] resulting colour:
//   r = clamp(depth / maxDepth, 0, 1) * 255
//   g = (cos(depth * 0.10) + 1) * 127
//   b = (sin(depth * 0.01) + 1) * 127
//   a = 255
//
// NOTE(review): a maxDepth of 0 or a non-finite input value is not handled
// specially here; presumably the producer of `input` never emits those -
// confirm against the caller.
kernel void colorit(constant float* paramter [[buffer(0)]], constant float* input [[buffer(1)]], device uchar4* output [[buffer(2)]], uint index [[thread_position_in_grid]])
{
  float maxDepth = paramter[0];
  float depth = input[index];

  // A smoothed iteration count can be slightly negative or larger than
  // maxDepth; clamp the ratio so the conversion to uchar stays in 0..255.
  float ratio = clamp(depth / maxDepth, 0.f, 1.f);

  // cos/sin lie in [-1, 1], so these two channels are already within 0..254.
  uchar r = uchar(ratio * 255.f);
  uchar g = uchar((cos(depth * 0.10f) + 1.f) * 127.f);
  uchar b = uchar((sin(depth * 0.01f) + 1.f) * 127.f);

  output[index] = uchar4(r, g, b, 255);
}


// Example only: a variant of the fractal kernel with vector width 4.
// Each thread handles the four consecutive pixels 4*index .. 4*index + 3 and
// writes them as one float4.
//
// Under Metal on the PowerVR chips this is noticeably slower than the scalar
// kernel, because memory throughput plays practically no role here.
// Attention: the threadGroupCount.width value must be divided by 4.
//
// paramter uses the same nine-float layout as the scalar kernel:
//   [0] maxDepth, [1] pixWidth, [2] pixHeight, [3] left, [4] top,
//   [5] width, [6] height, [7] a, [8] b
//
// Each output component receives depth - log2(log(r2) / 2) if that pixel's
// orbit escaped (r2 > 4) and maxDepth otherwise.
//
// NOTE(review): if pixWidth * pixHeight is not a multiple of 4, the last
// float4 still covers four elements - assumes the output buffer is sized
// accordingly; confirm against the caller.

kernel void fractalV4(constant float* paramter [[buffer(0)]], device float4* output [[buffer(1)]], uint index [[thread_position_in_grid]])
{
  int4 maxDepth = paramter[0];
  int4 pixWidth = paramter[1];
  int4 pixHeight = paramter[2];
  float4 left = paramter[3];
  float4 top = paramter[4];
  float4 width = paramter[5];
  float4 height = paramter[6];
  float4 a = paramter[7];
  float4 b = paramter[8];

  // An empty image would make the modulo/division below divide by zero.
  if (pixWidth.x <= 0 || pixHeight.x <= 0) return;

  // Surplus threads must not write past the last pixel.
  if (index * 4 >= uint(pixWidth.x) * uint(pixHeight.x)) return;

  int4 indexV = int4(index * 4) + int4(0, 1, 2, 3);

  float4 x = left + width * float4(indexV % pixWidth) / float4(pixWidth);
  float4 y = top + height * float4(indexV / pixWidth) / float4(pixHeight);
  float4 x2, y2, r2;
  float4 lr2 = 0.f;   // r2 of each lane at the moment that lane stopped
  int4 depth = 0;
  bool4 cont = true;  // lanes that are still iterating

  do
  {
    y2 = y * y;
    x2 = x * x;
    y = x * y * 2.f + b;
    x = x2 - y2 + a;
    r2 = x2 + y2;

    // Only lanes that were still running take the new r2; finished lanes keep
    // the value they stopped with while the others carry on.
    lr2 = select(lr2, r2, cont);

    // Sticky: a lane that has stopped never resumes, even if its (no longer
    // meaningful) r2 happens to fall back to <= 4 in a later iteration.
    cont = cont && (depth < maxDepth) && (r2 <= 4.f);
    depth += int4(cont);
  }
  while (any(cont));

  // The smoothing term needs log(r2) > 0, which only holds for escaped lanes.
  // Feed a harmless placeholder into the logarithms of the other lanes so no
  // lane evaluates log of a non-positive number; their result is discarded.
  bool4 escaped = lr2 > 4.f;
  float4 safeR2 = select(float4(8.f), lr2, escaped);
  float4 smoothed = float4(depth) - log(log(safeR2) * 0.5f) / M_LN2_F;

  output[index] = select(float4(maxDepth), smoothed, escaped);
}


// A very simple example ("quadrieren" = "to square"): each thread squares
// input[index] and assigns the result to output[index]. The float result is
// converted implicitly to uchar4 by that assignment.

kernel void quadrieren(constant float* input [[buffer(0)]], device uchar4* output [[buffer(1)]], uint index [[thread_position_in_grid]])
{
  const float value = input[index];
  const float squared = value * value;

  // Implicit float -> uchar4 conversion, exactly as in the original statement.
  output[index] = squared;
}






