#include <algorithm>
#include <chrono>
#include <cstdint>
#include <format>
#include <iostream>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>

#include "red_black.hpp"
#include "cartesian.hpp"
//using namespace red_black;

// Shorthand for the uniform integer distribution used to draw the random
// operation types, keys and values in the stress tests below.
using uid = std::uniform_int_distribution<int>;

/**
 * Randomized differential test and micro-benchmark for the two map
 * implementations.
 *
 * Drives red_black::map::Tree and cartesian::map::Tree with one shared random
 * sequence of insert / find / remove / size operations, compares every
 * find / remove / size result, and finally prints the mean time per call of
 * each operation for each tree. The cartesian insert also receives an extra
 * random argument `y` (presumably the treap priority -- confirm against
 * cartesian.hpp).
 *
 * The generator is default-seeded, so every run replays the same sequence.
 *
 * @throws std::runtime_error on the first result that differs between trees.
 */
void stress_map() {
    // steady_clock is monotonic; high_resolution_clock may alias system_clock,
    // which can jump and corrupt interval measurements.
    using Clock = std::chrono::steady_clock;

    const int OPERATIONS = 2'000'000;
    uid op_seg(0, 3);
    uid x_seg(1, 100'000);
    uid y_seg(1, 1'000'000'000);
    uid v_seg(1, 1'000'000'000);
    red_black::map::Tree rb;
    cartesian::map::Tree ca;
    //std::mt19937 rng(std::chrono::high_resolution_clock::now().time_since_epoch().count());
    std::mt19937 rng;

    Clock::duration ca_in{}, ca_fd{}, ca_re{}, ca_sz{};
    Clock::duration rb_in{}, rb_fd{}, rb_re{}, rb_sz{};
    std::int32_t n_in = 0, n_fd = 0, n_re = 0, n_sz = 0;

    // Renders an optional for diagnostics. On a mismatch at least one side is
    // often empty, and dereferencing an empty optional is undefined behaviour.
    const auto show = [](const std::optional<int>& value) -> std::string {
        return value ? std::to_string(*value) : std::string("none");
    };

    for (int op_index = 0; op_index < OPERATIONS; ++op_index) {
        // All four values are drawn every iteration so the random sequence
        // does not depend on which operation is chosen.
        int op_type = op_seg(rng);
        int x = x_seg(rng);
        int y = y_seg(rng);
        int v = v_seg(rng);
        if (op_type == 0) {
            auto t1 = Clock::now();
            rb.insert(x, v);
            auto t2 = Clock::now();
            ca.insert(x, v, y);
            auto t3 = Clock::now();
            ++n_in;
            rb_in += t2 - t1;
            ca_in += t3 - t2;
        } else if (op_type == 1) {
            auto t1 = Clock::now();
            std::optional<int> rbv = rb.find(x);
            auto t2 = Clock::now();
            std::optional<int> cav = ca.find(x);
            auto t3 = Clock::now();
            ++n_fd;
            rb_fd += t2 - t1;
            ca_fd += t3 - t2;
            if (rbv != cav)
                throw std::runtime_error(std::format(
                        "Failed find({}), {} != {}",
                        x, show(rbv), show(cav)
                ));
        } else if (op_type == 2) {
            auto t1 = Clock::now();
            std::optional<int> rbv = rb.remove(x);
            auto t2 = Clock::now();
            std::optional<int> cav = ca.remove(x);
            auto t3 = Clock::now();
            ++n_re;
            rb_re += t2 - t1;
            ca_re += t3 - t2;
            if (rbv != cav)
                throw std::runtime_error(std::format(
                        "Failed remove({}), {} != {}",
                        x, show(rbv), show(cav)
                ));
        } else if (op_type == 3) {
            auto t1 = Clock::now();
            size_t rbv = rb.size();
            auto t2 = Clock::now();
            size_t cav = ca.size();
            auto t3 = Clock::now();
            ++n_sz;
            rb_sz += t2 - t1;
            ca_sz += t3 - t2;
            if (rbv != cav)
                throw std::runtime_error(std::format(
                        "Failed size(), {} != {}",
                        rbv, cav
                ));
        }
    }

    // Prints "Mean <label> = <duration>"; a category that never ran reports
    // zero instead of dividing by zero.
    const auto report = [](const char* label, Clock::duration total, std::int32_t count) {
        const Clock::duration mean = count > 0 ? total / count : Clock::duration::zero();
        std::cout << std::format("Mean {} = {}\n", label, mean);
    };
    report("cart insert", ca_in, n_in);
    report("rb   insert", rb_in, n_in);
    report("cart find  ", ca_fd, n_fd);
    report("rb   find  ", rb_fd, n_fd);
    report("cart remove", ca_re, n_re);
    report("rb   remove", rb_re, n_re);
    report("cart size  ", ca_sz, n_sz);
    report("rb   size  ", rb_sz, n_sz);
}

/**
 * Randomized differential test and micro-benchmark: red_black::map::Tree
 * versus cartesian::ordered_map::Tree.
 *
 * Both trees receive one shared random sequence of insert / find / remove /
 * size operations (cartesian insert takes only key and value here). Every
 * find / remove / size result is compared, and the mean time per call of each
 * operation is printed for each tree at the end.
 *
 * The generator is default-seeded, so every run replays the same sequence.
 *
 * @throws std::runtime_error on the first result that differs between trees.
 */
void stress_ordered_map() {
    // steady_clock is monotonic; high_resolution_clock may alias system_clock,
    // which can jump and corrupt interval measurements.
    using Clock = std::chrono::steady_clock;

    const int OPERATIONS = 200'000;
    uid op_seg(0, 3);
    uid x_seg(1, 100'000);
    uid v_seg(1, 1'000'000'000);
    red_black::map::Tree rb;
    cartesian::ordered_map::Tree ca;
    //std::mt19937 rng(std::chrono::high_resolution_clock::now().time_since_epoch().count());
    std::mt19937 rng;

    Clock::duration ca_in{}, ca_fd{}, ca_re{}, ca_sz{};
    Clock::duration rb_in{}, rb_fd{}, rb_re{}, rb_sz{};
    std::int32_t n_in = 0, n_fd = 0, n_re = 0, n_sz = 0;

    // Renders an optional for diagnostics. On a mismatch at least one side is
    // often empty, and dereferencing an empty optional is undefined behaviour.
    const auto show = [](const std::optional<int>& value) -> std::string {
        return value ? std::to_string(*value) : std::string("none");
    };

    for (int op_index = 0; op_index < OPERATIONS; ++op_index) {
        // All three values are drawn every iteration so the random sequence
        // does not depend on which operation is chosen.
        int op_type = op_seg(rng);
        int x = x_seg(rng);
        int v = v_seg(rng);
        if (op_type == 0) {
            auto t1 = Clock::now();
            rb.insert(x, v);
            auto t2 = Clock::now();
            ca.insert(x, v);
            auto t3 = Clock::now();
            ++n_in;
            rb_in += t2 - t1;
            ca_in += t3 - t2;
        } else if (op_type == 1) {
            auto t1 = Clock::now();
            std::optional<int> rbv = rb.find(x);
            auto t2 = Clock::now();
            std::optional<int> cav = ca.find(x);
            auto t3 = Clock::now();
            ++n_fd;
            rb_fd += t2 - t1;
            ca_fd += t3 - t2;
            if (rbv != cav)
                throw std::runtime_error(std::format(
                        "Failed find({}), {} != {}",
                        x, show(rbv), show(cav)
                ));
        } else if (op_type == 2) {
            auto t1 = Clock::now();
            std::optional<int> rbv = rb.remove(x);
            auto t2 = Clock::now();
            std::optional<int> cav = ca.remove(x);
            auto t3 = Clock::now();
            ++n_re;
            rb_re += t2 - t1;
            ca_re += t3 - t2;
            if (rbv != cav)
                throw std::runtime_error(std::format(
                        "Failed remove({}), {} != {}",
                        x, show(rbv), show(cav)
                ));
        } else if (op_type == 3) {
            auto t1 = Clock::now();
            size_t rbv = rb.size();
            auto t2 = Clock::now();
            size_t cav = ca.size();
            auto t3 = Clock::now();
            ++n_sz;
            rb_sz += t2 - t1;
            ca_sz += t3 - t2;
            if (rbv != cav)
                throw std::runtime_error(std::format(
                        "Failed size(), {} != {}",
                        rbv, cav
                ));
        }
    }

    // Prints "Mean <label> = <duration>"; a category that never ran reports
    // zero instead of dividing by zero.
    const auto report = [](const char* label, Clock::duration total, std::int32_t count) {
        const Clock::duration mean = count > 0 ? total / count : Clock::duration::zero();
        std::cout << std::format("Mean {} = {}\n", label, mean);
    };
    report("cart insert", ca_in, n_in);
    report("rb   insert", rb_in, n_in);
    report("cart find  ", ca_fd, n_fd);
    report("rb   find  ", rb_fd, n_fd);
    report("cart remove", ca_re, n_re);
    report("rb   remove", rb_re, n_re);
    report("cart size  ", ca_sz, n_sz);
    report("rb   size  ", rb_sz, n_sz);
}

/**
 * Entry point: runs the ordered-map differential stress test.
 *
 * A mismatch between the two trees (reported by stress_ordered_map via an
 * exception) is printed to stderr and mapped to exit code 1. Letting it escape
 * main would instead call std::terminate and abort.
 */
int main() {
    try {
        stress_ordered_map();
    } catch (const std::exception& e) {
        std::cerr << e.what() << '\n';
        return 1;
    }
    return 0;
}
