#include <algorithm>
#include <chrono>
#include <cstdint>
#include <format>
#include <iostream>
#include <random>
#include <stdexcept>

#include "red_black.hpp"
#include "cartesian.hpp"
#include "smart_arr.hpp"

// Not `using namespace cartesian::set;`: every tree type below is spelled
// with its full namespace instead.
// Shorthand for a uniform distribution of ints over the closed range [a, b]
// (the default template argument of uniform_int_distribution is int).
using uid = std::uniform_int_distribution<>;

// Randomized differential test and micro-benchmark of cartesian::set::Tree
// against red_black::set::Tree.
//
// Runs OPERATIONS operations, each picked uniformly from insert / remove /
// find / size, with a key drawn from [0, 100000] by a fixed-seed generator.
// Every operation is applied to both trees, and the results of remove / find /
// size must agree. The cartesian insert also receives a second random argument
// drawn from [0, 1'000'000'000] (NOTE(review): presumably the node priority;
// confirm in cartesian.hpp). Finally it prints the mean time per call of every
// operation for each tree.
//
// Throws std::runtime_error("PROBLEMA 1!!!" / "PROBLEMA 2!!!" / "PROBLEMA 3!!!")
// on the first remove / find / size mismatch respectively.
void stress_set() {
    // steady_clock is monotonic. high_resolution_clock may be an alias of the
    // adjustable system_clock, which can jump and corrupt interval timings.
    using Clock = std::chrono::steady_clock;

    std::mt19937 rng(123);
    cartesian::set::Tree cart;
    red_black::set::Tree rb;
    const int OPERATIONS = 2'000'000;
    uid op_seg(0, 3);
    uid x_seg(0, 100000);
    uid y_seg(0, 1'000'000'000);

    Clock::duration
            cart_ins{},
            cart_find{},
            cart_rem{},
            cart_sz{},
            rb_ins{},
            rb_find{},
            rb_sz{},
            rb_rem{};
    int32_t n_ins = 0, n_find = 0, n_rem = 0, n_sz = 0;

    for (int oper_index = 0; oper_index < OPERATIONS; ++oper_index) {
        const int operation_type = op_seg(rng);
        const int x = x_seg(rng);
        const int y = y_seg(rng);  // drawn every iteration, used only by inserts
        if (operation_type == 0) {
            const auto t1 = Clock::now();
            cart.insert(x, y);
            const auto t2 = Clock::now();
            rb.insert(x);
            const auto t3 = Clock::now();
            cart_ins += t2 - t1;
            rb_ins += t3 - t2;
            ++n_ins;
        } else if (operation_type == 1) {
            const auto t1 = Clock::now();
            const bool cart_res = cart.remove(x);
            const auto t2 = Clock::now();
            const bool rb_res = rb.remove(x);
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 1!!!");
            cart_rem += t2 - t1;
            rb_rem += t3 - t2;
            ++n_rem;
        } else if (operation_type == 2) {
            const auto t1 = Clock::now();
            const bool cart_res = cart.find(x);
            const auto t2 = Clock::now();
            const bool rb_res = rb.find(x);
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 2!!!");
            cart_find += t2 - t1;
            rb_find += t3 - t2;
            ++n_find;
        } else if (operation_type == 3) {
            const auto t1 = Clock::now();
            const int cart_res = cart.size();
            const auto t2 = Clock::now();
            const int rb_res = rb.size();
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 3!!!");
            cart_sz += t2 - t1;
            rb_sz += t3 - t2;
            ++n_sz;
        }
    }

    // Prints the mean time per call, or "n/a" when the operation never ran.
    // Dividing an integer-rep duration by a zero count would be integer
    // division by zero (UB).
    const auto report = [](const char* label, Clock::duration total, int32_t count) {
        if (count == 0)
            std::cout << std::format("{} took n/a\n", label);
        else
            std::cout << std::format("{} took {}\n", label, total / count);
    };
    report("Insert in cart", cart_ins, n_ins);
    report("Insert in rb  ", rb_ins, n_ins);
    report("Find   in cart", cart_find, n_find);
    report("Find   in rb  ", rb_find, n_find);
    report("Remove in cart", cart_rem, n_rem);
    report("Remove in rb  ", rb_rem, n_rem);
    report("Size   in cart", cart_sz, n_sz);
    report("Size   in rb  ", rb_sz, n_sz);
}

// Randomized differential test and micro-benchmark of
// cartesian::ordered_set::Tree against red_black::set::Tree.
//
// Runs OPERATIONS operations, each picked uniformly from insert / remove /
// find / size, with a key drawn from [0, 100000] by a fixed-seed generator.
// Every operation is applied to both trees, and the results of remove / find /
// size must agree. Finally it prints the mean time per call of every operation
// for each tree. Only these four operations are exercised here.
//
// Throws std::runtime_error("PROBLEMA 1!!!" / "PROBLEMA 2!!!" / "PROBLEMA 3!!!")
// on the first remove / find / size mismatch respectively.
void stress_ordered_set() {
    // steady_clock is monotonic. high_resolution_clock may be an alias of the
    // adjustable system_clock, which can jump and corrupt interval timings.
    using Clock = std::chrono::steady_clock;

    std::mt19937 rng(123);
    cartesian::ordered_set::Tree cart;
    red_black::set::Tree rb;
    const int OPERATIONS = 2'000'000;
    uid op_seg(0, 3);
    uid x_seg(0, 100000);

    Clock::duration
            cart_ins{},
            cart_find{},
            cart_rem{},
            cart_sz{},
            rb_ins{},
            rb_find{},
            rb_sz{},
            rb_rem{};
    int32_t n_ins = 0, n_find = 0, n_rem = 0, n_sz = 0;

    for (int oper_index = 0; oper_index < OPERATIONS; ++oper_index) {
        const int operation_type = op_seg(rng);
        const int x = x_seg(rng);
        if (operation_type == 0) {
            const auto t1 = Clock::now();
            cart.insert(x);
            const auto t2 = Clock::now();
            rb.insert(x);
            const auto t3 = Clock::now();
            cart_ins += t2 - t1;
            rb_ins += t3 - t2;
            ++n_ins;
        } else if (operation_type == 1) {
            const auto t1 = Clock::now();
            const bool cart_res = cart.remove(x);
            const auto t2 = Clock::now();
            const bool rb_res = rb.remove(x);
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 1!!!");
            cart_rem += t2 - t1;
            rb_rem += t3 - t2;
            ++n_rem;
        } else if (operation_type == 2) {
            const auto t1 = Clock::now();
            const bool cart_res = cart.find(x);
            const auto t2 = Clock::now();
            const bool rb_res = rb.find(x);
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 2!!!");
            cart_find += t2 - t1;
            rb_find += t3 - t2;
            ++n_find;
        } else if (operation_type == 3) {
            const auto t1 = Clock::now();
            const int cart_res = cart.size();
            const auto t2 = Clock::now();
            const int rb_res = rb.size();
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 3!!!");
            cart_sz += t2 - t1;
            rb_sz += t3 - t2;
            ++n_sz;
        }
    }

    // Prints the mean time per call, or "n/a" when the operation never ran.
    // Dividing an integer-rep duration by a zero count would be integer
    // division by zero (UB).
    const auto report = [](const char* label, Clock::duration total, int32_t count) {
        if (count == 0)
            std::cout << std::format("{} took n/a\n", label);
        else
            std::cout << std::format("{} took {}\n", label, total / count);
    };
    report("Insert in cart", cart_ins, n_ins);
    report("Insert in rb  ", rb_ins, n_ins);
    report("Find   in cart", cart_find, n_find);
    report("Find   in rb  ", rb_find, n_find);
    report("Remove in cart", cart_rem, n_rem);
    report("Remove in rb  ", rb_rem, n_rem);
    report("Size   in cart", cart_sz, n_sz);
    report("Size   in rb  ", rb_sz, n_sz);
}

// Randomized differential test and micro-benchmark of
// cartesian::smart_array::Tree against smart_arr::smart_array::Tree.
//
// Runs OPERATIONS operations using a fixed-seed generator:
// - insert a random value from [0, 100000] at a random position in [0, size];
// - remove the element at a random position;
// - read the element at a random position with at();
// - compare sizes.
// Inserts are drawn with probability 2/5 and each other operation with 1/5.
// Remove and at() are skipped while the container is empty. Positions are
// chosen from the cartesian container's size. Every operation is applied to
// both containers and their results must agree. Finally it prints the mean
// time per call of every operation for each container.
//
// Throws std::runtime_error("PROBLEMA 1!!!" / "PROBLEMA 2!!!" / "PROBLEMA 3!!!")
// on the first remove / at / size mismatch respectively.
void stress_smart_array() {
    // steady_clock is monotonic. high_resolution_clock may be an alias of the
    // adjustable system_clock, which can jump and corrupt interval timings.
    using Clock = std::chrono::steady_clock;

    std::mt19937 rng(123);
    cartesian::smart_array::Tree cart;
    // Reference implementation. It is named `rb` (and reported as "rb") for
    // symmetry with the other stress tests, although it is smart_arr's
    // container rather than a red-black tree.
    smart_arr::smart_array::Tree rb;
    const int OPERATIONS = 5'000'000;
    uid op_seg(0, 4);  // 0 and 4 both mean "insert"
    uid x_seg(0, 100000);

    Clock::duration
            cart_ins{},
            cart_find{},
            cart_rem{},
            cart_sz{},
            rb_ins{},
            rb_find{},
            rb_sz{},
            rb_rem{};
    int32_t n_ins = 0, n_find = 0, n_rem = 0, n_sz = 0;

    for (int oper_index = 0; oper_index < OPERATIONS; ++oper_index) {
        const int operation_type = op_seg(rng);
        const int x = x_seg(rng);
        if (operation_type == 0 || operation_type == 4) {
            const int pos = uid(0, cart.size())(rng);  // any slot, including the end
            const auto t1 = Clock::now();
            cart.insert(pos, x);
            const auto t2 = Clock::now();
            rb.insert(pos, x);
            const auto t3 = Clock::now();
            cart_ins += t2 - t1;
            rb_ins += t3 - t2;
            ++n_ins;
        } else if (operation_type == 1) {
            if (cart.size() == 0) continue;  // nothing to remove
            const int pos = uid(0, cart.size() - 1)(rng);
            const auto t1 = Clock::now();
            // `auto`, not `char`: narrowing the returned elements to char would
            // compare only their low bits and could hide real mismatches.
            const auto cart_res = cart.remove(pos);
            const auto t2 = Clock::now();
            const auto rb_res = rb.remove(pos);
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 1!!!");
            cart_rem += t2 - t1;
            rb_rem += t3 - t2;
            ++n_rem;
        } else if (operation_type == 2) {
            if (cart.size() == 0) continue;  // nothing to read
            const int pos = uid(0, cart.size() - 1)(rng);
            const auto t1 = Clock::now();
            const auto cart_res = cart.at(pos);  // full value, see remove above
            const auto t2 = Clock::now();
            const auto rb_res = rb.at(pos);
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 2!!!");
            cart_find += t2 - t1;
            rb_find += t3 - t2;
            ++n_find;
        } else if (operation_type == 3) {
            const auto t1 = Clock::now();
            const int cart_res = cart.size();
            const auto t2 = Clock::now();
            const int rb_res = rb.size();
            const auto t3 = Clock::now();
            if (cart_res != rb_res)
                throw std::runtime_error("PROBLEMA 3!!!");
            cart_sz += t2 - t1;
            rb_sz += t3 - t2;
            ++n_sz;
        }
    }

    // Prints the mean time per call, or "n/a" when the operation never ran.
    // Dividing an integer-rep duration by a zero count would be integer
    // division by zero (UB).
    const auto report = [](const char* label, Clock::duration total, int32_t count) {
        if (count == 0)
            std::cout << std::format("{} took n/a\n", label);
        else
            std::cout << std::format("{} took {}\n", label, total / count);
    };
    report("Insert in cart", cart_ins, n_ins);
    report("Insert in rb  ", rb_ins, n_ins);
    report("Find   in cart", cart_find, n_find);
    report("Find   in rb  ", rb_find, n_find);
    report("Remove in cart", cart_rem, n_rem);
    report("Remove in rb  ", rb_rem, n_rem);
    report("Size   in cart", cart_sz, n_sz);
    report("Size   in rb  ", rb_sz, n_sz);
}

// Entry point. Only the smart-array benchmark is run; stress_set() and
// stress_ordered_set() are defined above but not invoked.
int main() {
    stress_smart_array();
    // Falling off the end of main returns 0.
}
