// You can find a digital copy of this project at https://godbolt.org/z/5s4jW1qW5
#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <format>
#include <iostream>
#include <optional>
#include <random>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "cartesian.hpp"
#include "cartesian_sized.hpp"
#include "cartesian_xless.hpp"
#include "cartesian_yless.hpp"
#include "vector_wrap.hpp"

using namespace std::chrono_literals;

/// Benchmarks two set implementations on the same random operation stream
/// and cross-checks their answers.
///
/// Each operation is chosen uniformly from:
///   1. insert(x)
///   2. erase(x)    -- the return value is recorded
///   3. contains(x) -- the return value is recorded
/// x is drawn uniformly from [0, OPERATIONS / 10]. The generator is a
/// default-seeded std::mt19937, so every run benchmarks the same stream.
///
/// Both sets are timed on exactly the same work, including the destruction
/// of the set. The reported mean size comes from a separate, untimed replay
/// on Set1, so collecting it does not inflate Set1's timing. Only Set1 needs
/// size().
///
/// @param name_1      display name for Set1
/// @param name_2      display name for Set2
/// @param OPERATIONS  number of operations to generate; must be positive
template<class Set1, class Set2>
void benchmark_two_sets(const std::string& name_1, const std::string& name_2, int OPERATIONS) {
    std::cout << std::format("Benchmarking Set: '{}' against '{}'\n", name_1, name_2);
    if (OPERATIONS <= 0) {
        // Guards the per-operation divisions below and the vector size conversion.
        std::cout << std::format("OPERATIONS must be positive, got {}\n", OPERATIONS) << '\n';
        return;
    }

    const int MAX_X = OPERATIONS / 10;
    std::vector<std::pair<int, int>> ops(OPERATIONS);  // {type, x}
    std::mt19937 rng;  // default seed: the stream is reproducible across runs
    using uid = std::uniform_int_distribution<int>;
    uid type_distr(1, 3);
    uid x_distr(0, MAX_X);
    for (auto& [type, x] : ops) {
        type = type_distr(rng);
        x = x_distr(rng);
    }

    // Applies every operation to `container`. It records erase/contains
    // results in `answers` and calls `after_each(container)` after each step.
    // Both timed runs go through this one lambda, so they do identical work.
    auto replay = [&ops](auto& container, std::vector<int>& answers, auto&& after_each) {
        for (const auto& [type, x] : ops) {
            if (type == 1) {
                container.insert(x);
            } else if (type == 2) {
                answers.push_back(container.erase(x));
            } else if (type == 3) {
                answers.push_back(container.contains(x));
            }
            after_each(container);
        }
    };
    const auto do_nothing = [](auto&) {};

    std::vector<int> answers_1;
    answers_1.reserve(OPERATIONS);
    const auto t1_start = std::chrono::high_resolution_clock::now();
    {
        Set1 set_1;
        replay(set_1, answers_1, do_nothing);
    }  // destruction is part of the measured time
    const auto t1_end = std::chrono::high_resolution_clock::now();

    std::vector<int> answers_2;
    answers_2.reserve(OPERATIONS);
    const auto t2_start = std::chrono::high_resolution_clock::now();
    {
        Set2 set_2;
        replay(set_2, answers_2, do_nothing);
    }
    const auto t2_end = std::chrono::high_resolution_clock::now();

    // Untimed replay whose only purpose is to sum the size after every operation.
    std::size_t sum_of_sizes = 0;
    {
        Set1 sizing_set;
        std::vector<int> scratch_answers;
        scratch_answers.reserve(OPERATIONS);
        replay(sizing_set, scratch_answers,
               [&sum_of_sizes](auto& container) { sum_of_sizes += container.size(); });
    }

    std::cout << std::format("{}ns per operation in {}\n", (t1_end - t1_start) / 1.ns / OPERATIONS, name_1);
    std::cout << std::format("{}ns per operation in {}\n", (t2_end - t2_start) / 1.ns / OPERATIONS, name_2);
    std::cout << std::format("Mean size is {}\n", static_cast<double>(sum_of_sizes) / OPERATIONS);
    if (answers_1 != answers_2)
        std::cout << std::format("EVERYTHING IS BAD\n");
    std::cout << '\n';
}

/// Benchmarks two set implementations that both support size(). Both run the
/// same random operation stream, and their answers are cross-checked.
///
/// Each operation is chosen uniformly from:
///   1. insert(x)
///   2. erase(x)    -- the return value is recorded
///   3. contains(x) -- the return value is recorded
///   4. size()      -- the return value is recorded
/// x is drawn uniformly from [0, OPERATIONS / 10]. The generator is a
/// default-seeded std::mt19937, so every run benchmarks the same stream.
///
/// Both sets are timed on exactly the same work, including the destruction
/// of the set. The reported mean size comes from a separate, untimed replay
/// on Set1, so collecting it does not inflate Set1's timing.
///
/// @param name_1      display name for Set1
/// @param name_2      display name for Set2
/// @param OPERATIONS  number of operations to generate; must be positive
template<class Set1, class Set2>
void benchmark_two_sets_with_sizes(const std::string& name_1, const std::string& name_2, int OPERATIONS) {
    std::cout << std::format("Benchmarking Set (with size query): '{}' against '{}'\n", name_1, name_2);
    if (OPERATIONS <= 0) {
        // Guards the per-operation divisions below and the vector size conversion.
        std::cout << std::format("OPERATIONS must be positive, got {}\n", OPERATIONS) << '\n';
        return;
    }

    const int MAX_X = OPERATIONS / 10;
    std::vector<std::pair<int, int>> ops(OPERATIONS);  // {type, x}
    std::mt19937 rng;  // default seed: the stream is reproducible across runs
    using uid = std::uniform_int_distribution<int>;
    uid type_distr(1, 4);
    uid x_distr(0, MAX_X);
    for (auto& [type, x] : ops) {
        type = type_distr(rng);
        x = x_distr(rng);
    }

    // Applies every operation to `container`. It records erase/contains/size
    // results in `answers` and calls `after_each(container)` after each step.
    // Both timed runs go through this one lambda, so they do identical work.
    auto replay = [&ops](auto& container, std::vector<int>& answers, auto&& after_each) {
        for (const auto& [type, x] : ops) {
            if (type == 1) {
                container.insert(x);
            } else if (type == 2) {
                answers.push_back(container.erase(x));
            } else if (type == 3) {
                answers.push_back(container.contains(x));
            } else if (type == 4) {
                answers.push_back(container.size());
            }
            after_each(container);
        }
    };
    const auto do_nothing = [](auto&) {};

    std::vector<int> answers_1;
    answers_1.reserve(OPERATIONS);
    const auto t1_start = std::chrono::high_resolution_clock::now();
    {
        Set1 set_1;
        replay(set_1, answers_1, do_nothing);
    }  // destruction is part of the measured time
    const auto t1_end = std::chrono::high_resolution_clock::now();

    std::vector<int> answers_2;
    answers_2.reserve(OPERATIONS);
    const auto t2_start = std::chrono::high_resolution_clock::now();
    {
        Set2 set_2;
        replay(set_2, answers_2, do_nothing);
    }
    const auto t2_end = std::chrono::high_resolution_clock::now();

    // Untimed replay whose only purpose is to sum the size after every operation.
    std::size_t sum_of_sizes = 0;
    {
        Set1 sizing_set;
        std::vector<int> scratch_answers;
        scratch_answers.reserve(OPERATIONS);
        replay(sizing_set, scratch_answers,
               [&sum_of_sizes](auto& container) { sum_of_sizes += container.size(); });
    }

    std::cout << std::format("{}ns per operation in {}\n", (t1_end - t1_start) / 1.ns / OPERATIONS, name_1);
    std::cout << std::format("{}ns per operation in {}\n", (t2_end - t2_start) / 1.ns / OPERATIONS, name_2);
    std::cout << std::format("Mean size is {}\n", static_cast<double>(sum_of_sizes) / OPERATIONS);
    if (answers_1 != answers_2)
        std::cout << std::format("EVERYTHING IS BAD\n");
    std::cout << '\n';
}

/// Benchmarks two positional-container ("SmartArray") implementations on the
/// same random operation stream and cross-checks their answers.
///
/// Each operation carries an index i and a value x:
///   1. insert(i, x) -- i drawn from [0, size]
///   2. erase(i)     -- i drawn from [0, size - 1]; the return value is recorded
///   3. at(i)        -- i drawn from [0, size - 1]; the return value is recorded
///   4. size()       -- the return value is recorded
/// INSERT is twice as likely as each other operation. The generator tracks
/// the size the containers will have, so every generated index is in range.
/// While that size is 0, ERASE and AT are replaced by size(). x is drawn from
/// [0, OPERATIONS / 10]. The generator is a default-seeded std::mt19937, so
/// every run benchmarks the same stream.
///
/// Both containers are timed on exactly the same work. The reported mean size
/// comes from a separate, untimed replay on Set1, so it does not skew Set1's
/// timing.
///
/// @param name_1      display name for Set1
/// @param name_2      display name for Set2
/// @param OPERATIONS  number of operations to generate; must be positive
template<class Set1, class Set2>
void benchmark_smart_array(const std::string& name_1, const std::string& name_2, int OPERATIONS) {
    std::cout << std::format("Benchmarking SmartArray (with size query): '{}' against '{}'\n", name_1, name_2);
    if (OPERATIONS <= 0) {
        // Guards the per-operation divisions below and the vector size conversion.
        std::cout << std::format("OPERATIONS must be positive, got {}\n", OPERATIONS) << '\n';
        return;
    }

    const int MAX_X = OPERATIONS / 10;
    std::vector<std::array<int, 3>> ops(OPERATIONS);  // {type, index, x}
    std::mt19937 rng;  // default seed: the stream is reproducible across runs
    using uid = std::uniform_int_distribution<int>;
    uid type_distr(0, 4);  // 0 is remapped to INSERT, doubling its probability
    uid x_distr(0, MAX_X);
    int size = 0;  // size the containers will have after the operations so far
    for (auto& [type, index, x] : ops) {
        type = type_distr(rng);
        if (type == 0)
            type = 1;
        if (size == 0 && (type == 2 || type == 3))
            type = 4;  // nothing to erase or read yet
        int max_index = 0;
        if (type == 1) {
            max_index = size;      // insertion may append at the end
        } else if (type == 2 || type == 3) {
            max_index = size - 1;  // must address an existing element
        }
        index = uid(0, max_index)(rng);
        x = x_distr(rng);
        if (type == 1)
            ++size;
        else if (type == 2)
            --size;
    }

    // Applies every operation to `container`. It records erase/at/size results
    // in `answers` and calls `after_each(container)` after each step. Both
    // timed runs go through this one lambda, so they do identical work.
    auto replay = [&ops](auto& container, std::vector<int>& answers, auto&& after_each) {
        for (const auto& [type, index, x] : ops) {
            if (type == 1) {
                container.insert(index, x);
            } else if (type == 2) {
                answers.push_back(container.erase(index));
            } else if (type == 3) {
                answers.push_back(container.at(index));
            } else if (type == 4) {
                answers.push_back(container.size());
            }
            after_each(container);
        }
    };
    const auto do_nothing = [](auto&) {};

    std::vector<int> answers_1;
    answers_1.reserve(OPERATIONS);
    const auto t1_start = std::chrono::high_resolution_clock::now();
    {
        Set1 set_1;
        replay(set_1, answers_1, do_nothing);
    }  // destruction is part of the measured time
    const auto t1_end = std::chrono::high_resolution_clock::now();

    std::vector<int> answers_2;
    answers_2.reserve(OPERATIONS);
    const auto t2_start = std::chrono::high_resolution_clock::now();
    {
        Set2 set_2;
        replay(set_2, answers_2, do_nothing);
    }
    const auto t2_end = std::chrono::high_resolution_clock::now();

    // Untimed replay whose only purpose is to sum the size after every operation.
    std::size_t sum_of_sizes = 0;
    {
        Set1 sizing_set;
        std::vector<int> scratch_answers;
        scratch_answers.reserve(OPERATIONS);
        replay(sizing_set, scratch_answers,
               [&sum_of_sizes](auto& container) { sum_of_sizes += container.size(); });
    }

    std::cout << std::format("{}ns per operation in {}\n", (t1_end - t1_start) / 1.ns / OPERATIONS, name_1);
    std::cout << std::format("{}ns per operation in {}\n", (t2_end - t2_start) / 1.ns / OPERATIONS, name_2);
    std::cout << std::format("Mean size is {}\n", static_cast<double>(sum_of_sizes) / OPERATIONS);
    if (answers_1 != answers_2)
        std::cout << std::format("EVERYTHING IS BAD\n");
    std::cout << '\n';
}

/// Scratch demo of cartesian::Tree. It inserts a shuffled permutation of
/// (x, y) pairs and prints the tree after each insertion. It then prints the
/// tree level by level and looks up every key in [0, N]. This code used to
/// sit unreachable after main's `return 0;` and is kept here for debugging.
/// It is not called from main().
[[maybe_unused]] static void demo_cartesian_tree() {
    const int N = 9;
    std::vector<int> xs(N), ys(N);
    for (int i = 0; i < N; ++i) {
        xs[i] = i;
        ys[i] = i;
    }
    std::mt19937 rng;
    std::shuffle(ys.begin(), ys.end(), rng);
    std::vector<std::pair<int, int>> pairs;
    for (int i = 0; i < N; ++i) {
        std::cout << std::format("({}, {}) ", xs[i], ys[i]);
        pairs.emplace_back(xs[i], ys[i]);
    }
    std::cout << '\n';
    std::shuffle(pairs.begin(), pairs.end(), rng);

    cartesian::Tree<int, int> tree;
    for (const auto& [x, y] : pairs) {
        tree.insert(x, y);
        std::cout << std::format("({} {}): {}\n", x, y, tree.to_string());
    }

    // One output row per level. Characters that belong to other levels are
    // printed as spaces, so the columns line up.
    const auto [text, levels] = tree.to_leveled_string();
    std::size_t printed = 0;
    for (int level = 0; printed < text.size(); ++level) {
        for (std::size_t i = 0; i < text.size(); ++i) {
            if (levels[i] == level) {
                std::cout << text[i];
                ++printed;
            } else {
                std::cout << ' ';
            }
        }
        std::cout << '\n';
    }

    auto check_have = [&tree](int key) {
        std::optional<int> ans = tree.find(key);
        if (ans)
            std::cout << std::format("We have ({}, {})\n", key, *ans);
        else
            std::cout << key << " is std::nullopt\n";
    };
    // Keys are 0..N-1, so querying N exercises the "missing" branch too.
    for (int key = 0; key <= N; ++key)
        check_have(key);
}

/// Reads the number of operations from stdin and runs every benchmark with it.
/// Returns 1 without benchmarking if the input is not a positive integer.
int main() {
    int operations = 0;
    if (!(std::cin >> operations) || operations <= 0) {
        std::cerr << "Expected a positive operation count on stdin\n";
        return 1;
    }
    benchmark_two_sets<std::set<int>, cartesian::Set<int>>("std::set<int>", "cartesian::Set<int>", operations);
    benchmark_two_sets_with_sizes<std::set<int>, cartesian_sized::Set<int>>("std::set<int>", "cartesian_sized::Set<int>", operations);
    benchmark_two_sets_with_sizes<std::set<int>, cartesian_yless::Set<int>>("std::set<int>", "cartesian_yless::Set<int>", operations);
    benchmark_smart_array<vector_wrap::SmartArray<int>, cartesian_xless::SmartArray<int>>("vector_wrap::SmartArray<int>", "cartesian_xless::SmartArray<int>", operations);
    return 0;
}
