#include "render.hpp"
#include <cstdint>
#include <cassert>

/// One 8-bit-per-channel RGB pixel: red, green, blue, in that member order.
///
/// render_base writes pixels by reinterpreting the raw std::byte row pointer
/// as an array of rgb8_t, so the member order, the absence of padding and the
/// byte alignment are part of this type's contract (presumably the target
/// buffer is packed 24-bit RGB — confirm with the buffer's producer).
struct rgb8_t {
  std::uint8_t r;
  std::uint8_t g;
  std::uint8_t b;
};

// Lock in the layout the buffer reinterpretation depends on: exactly three
// bytes per pixel, and usable at any byte address.
static_assert(sizeof(rgb8_t) == 3, "rgb8_t must be exactly 3 bytes (no padding)");
static_assert(alignof(rgb8_t) == 1, "rgb8_t must be byte-aligned");

/// Map a normalized value to a color on a blue -> red "heat" gradient.
///
/// The range [0, 1) is split into four equal segments. Within each segment one
/// channel is linearly interpolated and truncated to 8 bits:
///   [0,   1/4) : blue   -> cyan    (green ramps up)
///   [1/4, 1/2) : cyan   -> green   (blue ramps down)
///   [1/2, 3/4) : green  -> yellow  (red ramps up)
///   [3/4, 1  ) : yellow -> red     (green ramps down)
/// The value 1 itself maps to black.
/// NOTE(review): the black at x == 1 looks like a deliberate sentinel (e.g. a
/// "reached the iteration limit" marker) rather than part of the gradient —
/// confirm with callers before changing it.
///
/// @param x  Normalized value, expected in [0, 1] (asserted in debug builds).
///           In release builds negative values are treated as 0 and values
///           above 1 map to black.
/// @return   The corresponding 8-bit RGB color.
rgb8_t heat_lut(float x)
{
  assert(0 <= x && x <= 1);

  // The assert disappears under NDEBUG. Clamp negatives so the float ->
  // uint8_t conversions below never receive a negative value (converting an
  // out-of-range float to an integer type is undefined behavior).
  if (x < 0.f)
    x = 0.f;

  constexpr float x0 = 1.f / 4.f; // end of segment 1, also the segment width
  constexpr float x1 = 2.f / 4.f; // end of segment 2
  constexpr float x2 = 3.f / 4.f; // end of segment 3

  if (x < x0)
  {
    auto g = static_cast<std::uint8_t>(x / x0 * 255);
    return rgb8_t{0, g, 255};
  }
  else if (x < x1)
  {
    auto b = static_cast<std::uint8_t>((x1 - x) / x0 * 255);
    return rgb8_t{0, 255, b};
  }
  else if (x < x2)
  {
    auto r = static_cast<std::uint8_t>((x - x1) / x0 * 255);
    return rgb8_t{r, 255, 0};
  }
  else if (x < 1)
  {
    // In this segment it is the green channel that ramps down to 0.
    auto g = static_cast<std::uint8_t>((1.f - x) / x0 * 255);
    return rgb8_t{255, g, 0};
  }
  else
  {
    return {0, 0, 0};
  }
}


/// Baseline renderer: fills the image with a radial heat gradient whose origin
/// is pixel (0, 0).
///
/// Pixel (x, y) is set to heat_lut((x*x + y*y) / (width*width + height*height)).
/// Because x < width and y < height, the integer numerator is always smaller
/// than the denominator.
///
/// @param buffer        Start of the first row. Each row receives `width`
///                      consecutive rgb8_t pixels.
/// @param width         Image width in pixels.
/// @param height        Image height in pixels.
/// @param stride        Byte offset between the starts of consecutive rows.
/// @param n_iterations  Not used by this implementation.
void render_base(std::byte* buffer,
            int width,
            int height,
            std::ptrdiff_t stride,
            [[maybe_unused]] int n_iterations)
{
  // Squared distances are computed in 64 bits: in `int`, both the numerator
  // and `width * width + height * height` overflow (undefined behavior) once
  // the image diagonal exceeds ~46340 pixels. Where the int version did not
  // overflow, converting the same value to float gives the same result.
  const std::int64_t w = width;
  const std::int64_t h = height;
  const float denom = static_cast<float>(w * w + h * h); // loop-invariant

  for (int y = 0; y < height; ++y)
  {
    rgb8_t* lineptr = reinterpret_cast<rgb8_t*>(buffer);
    const std::int64_t y2 = std::int64_t{y} * y;

    for (int x = 0; x < width; ++x)
    {
      const std::int64_t d2 = std::int64_t{x} * x + y2;
      lineptr[x] = heat_lut(static_cast<float>(d2) / denom);
    }

    buffer += stride;
  }
}


/// Placeholder renderer, presumably for a multi-threaded variant (inferred
/// from the `_mt` suffix — confirm).
///
/// Takes the same arguments as render_base. It is currently a no-op: the
/// buffer is left untouched.
void render_mt([[maybe_unused]] std::byte* buffer,
               [[maybe_unused]] int width,
               [[maybe_unused]] int height,
               [[maybe_unused]] std::ptrdiff_t stride,
               [[maybe_unused]] int n_iterations)
{
  // TODO: not implemented yet.
}

/// Placeholder renderer, presumably for an optimized scalar variant (inferred
/// from the `_optim` suffix — confirm).
///
/// Takes the same arguments as render_base. It is currently a no-op: the
/// buffer is left untouched.
void render_optim([[maybe_unused]] std::byte* buffer,
                  [[maybe_unused]] int width,
                  [[maybe_unused]] int height,
                  [[maybe_unused]] std::ptrdiff_t stride,
                  [[maybe_unused]] int n_iterations)
{
  // TODO: not implemented yet.
}

/// Placeholder renderer, presumably for a SIMD-vectorized variant (inferred
/// from the `_simd` suffix — confirm).
///
/// Takes the same arguments as render_base. It is currently a no-op: the
/// buffer is left untouched.
void render_simd([[maybe_unused]] std::byte* buffer,
                 [[maybe_unused]] int width,
                 [[maybe_unused]] int height,
                 [[maybe_unused]] std::ptrdiff_t stride,
                 [[maybe_unused]] int n_iterations)
{
  // TODO: not implemented yet.
}

/// Placeholder renderer, presumably for the final combined variant (inferred
/// from the `_final` suffix — confirm).
///
/// Takes the same arguments as render_base. It is currently a no-op: the
/// buffer is left untouched.
void render_final([[maybe_unused]] std::byte* buffer,
                  [[maybe_unused]] int width,
                  [[maybe_unused]] int height,
                  [[maybe_unused]] std::ptrdiff_t stride,
                  [[maybe_unused]] int n_iterations)
{
  // TODO: not implemented yet.
}
