#include "emp-sh2pc/emp-sh2pc.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

#include <boost/random/laplace_distribution.hpp>
#include <boost/random/linear_congruential.hpp>

using namespace std::chrono;
using namespace emp;
using namespace std;

// Synthetic record values (and the sample_laplace seed) are reduced into [0, MAX_DATUM).
constexpr int MAX_DATUM = 10000;
// Bounds of the emulated range query; together they cover the whole value domain.
constexpr int QUERY_MIN = 0;
constexpr int QUERY_MAX = MAX_DATUM;

double sample_laplace(double mean, double scale)
{
	auto seed = rand() % MAX_DATUM;
	boost::minstd_rand generator(seed);

	auto laplace_distribution = boost::random::laplace_distribution<double>(mean, scale);

	return laplace_distribution(generator);
}

// Computes {location, scale} of the (pre-truncation) Laplace noise that pads the result size.
//
// With delta_c = log2(N)^2 and delta = beta / n:
//   location = delta_c - delta_c * ln(delta * (e^(epsilon / delta_c) + 1)) / epsilon
//   scale    = delta_c
// NOTE(review): a textbook Laplace mechanism with sensitivity delta_c uses scale
// delta_c / epsilon; confirm that returning plain delta_c here is intended.
pair<double, double> get_parameters_for_laplace(double epsilon, int N, int n, double beta)
{
	const double levels = log2(N);
	const double squared_levels = levels * levels;  // delta_c
	const double per_item_beta = beta / n;          // delta

	const double tail_shift = squared_levels * log(per_item_beta * (exp(epsilon / squared_levels) + 1.0)) / epsilon;
	return make_pair(squared_levels - tail_shift, squared_levels);
}

void shrinkwrap_range_query(int party, int data_size, double selectivity, double epsilon, double beta, int N, int n)
{
	auto *data = new Integer[data_size];
	auto query_min = Integer(32, QUERY_MIN, ALICE);
	auto query_max = Integer(32, QUERY_MAX, BOB);

	auto count = Integer(32, 0, ALICE);
	auto one = Integer(32, 1, BOB);

	// Here we generate the data synthetically, because this prototype makes a linear scan anyway.
	// That is, the result will be the same regardless of distribution.

	// Half from server 1 (Alice)
	for (int i = 0; i < data_size / 2; ++i)
	{
		data[i] = Integer(32, rand() % MAX_DATUM, ALICE);
	}

	// Half from server 2 (Bob)
	for (int i = data_size / 2; i < data_size; ++i)
	{
		data[i] = Integer(32, rand() % MAX_DATUM, BOB);
	}

	auto start = high_resolution_clock::now();

	// To answer a range query, Shrinkwrap first scans the input remembering which record is filtered in or out
	for (int i = 0; i < data_size; ++i)
	{
		// In this prototype we are not really interested in the true result, so we take time to compute the bit, but then discard it
		auto in = (data[i] > query_min) & (data[i] < query_max);
		// We keep the count of the result's size; to make it oblivious we always add a number (0 or 1), but we don't care which one here
		count = count + one;
	}

	// Shrinkwrap then sorts the records according to their filtering bit, but since bitonic sort's runtime is data-agnostic, we'll just sort the input
	sort(data, data_size);

	// Shrinkwrap says in the paper that it uses "truncated Laplacian", which, according to the description, is max(s, 0), where s is the sample.
	// We have adapted the truncated Laplacian mechanism to the hierarchical method that we used in Epsolute.
	auto [mean, scale] = get_parameters_for_laplace(epsilon, N, n, beta);
	cout << "The mean and scale are : " << mean << " and " << scale << endl;

	auto noise = (int)max(sample_laplace(mean, epsilon), 0.0);

	cout << "The scan/sort/noise took (microseconds): " << duration_cast<microseconds>(high_resolution_clock::now() - start).count() << endl;
	cout << "The noise is : " << noise << endl;

	// Here, suppose that the count is correct (i.e. data_size * selectivity) and we add noise.
	// We emulate actual "decryption" by .reveal() and since sending over a blob of result is quick compared to CPU work, we skip it.
	for (int i = 0; i < data_size * selectivity + noise; ++i)
	{
		data[i].reveal<int32_t>();
	}

	cout << "Total: " << duration_cast<microseconds>(high_resolution_clock::now() - start).count() << endl;

	delete[] data;
}

int main(int argc, char **argv)
{
	int data_size = 1000;
	double selectivity = 0.005;
	double epsilon = 0.679;
	double beta = 1.0 / (1 << 20);
	int N = 10000;
	int n = 4369;

	int port, party;
	parse_party_and_port(argv, &party, &port);

	if (argc > 3)
	{
		data_size = atoi(argv[3]);
	}

	cout << "party: " << (party == ALICE ? "ALICE" : "BOB") << endl;

	cout << "data_size: " << data_size << endl;
	cout << "selectivity: " << selectivity << endl;
	cout << "epsilon: " << epsilon << endl;
	cout << "beta: " << beta << endl;
	cout << "N: " << N << endl;
	cout << "n: " << n << endl;

	auto *io = new NetIO(party == ALICE ? nullptr : "10.142.0.18", port);

	setup_semi_honest(io, party);
	shrinkwrap_range_query(party, data_size, selectivity, epsilon, beta, N, n);
	finalize_semi_honest();

	delete io;
}