From 16d880a67ef662ac00de363ad533bc75d0deb7fa Mon Sep 17 00:00:00 2001
From: Weiqi <weltch1997@gmail.com>
Date: Thu, 27 Apr 2023 03:05:16 -0400
Subject: [PATCH] refactor

---
 include/field.h       |  24 +++++-----
 include/group.h       |   4 +-
 include/helper.h      |   4 ++
 include/hnsw.h        |  57 +++++++++++++++++++++++-
 include/ipre.h        |  10 ++---
 include/matrix.h      |   2 +-
 include/vector.h      |   2 +-
 src/field.cpp         |  40 ++++++++---------
 src/group.cpp         |   4 +-
 src/helper.cpp        |  10 ++++-
 src/hnsw.cpp          | 101 ++++++++++++++++++++++++++++++++++++++++++
 src/ipre.cpp          |  14 +++---
 src/matrix.cpp        |  18 ++++----
 src/vector.cpp        |   8 ++--
 tests/test_field.cpp  |  34 +++++++-------
 tests/test_group.cpp  |   4 +-
 tests/test_helper.cpp |  29 +++++++++++-
 tests/test_ipre.cpp   |   6 +--
 18 files changed, 282 insertions(+), 89 deletions(-)

diff --git a/include/field.h b/include/field.h
index 67cd008..0560743 100644
--- a/include/field.h
+++ b/include/field.h
@@ -7,31 +7,31 @@ extern "C" {
 #include "relic/relic.h"
 }
 
-struct zp {
+struct ZP {
     bn_t point{};
     bn_t modular{};
 };
 
-zp rand_zp(bn_t modular);
+ZP rand_zp(bn_t modular);
 
-zp zp_zero(bn_t modular);
+ZP zp_zero(bn_t modular);
 
-zp zp_one(bn_t modular);
+ZP zp_one(bn_t modular);
 
-zp zp_copy(zp x);
+ZP zp_copy(ZP x);
 
-zp zp_from_int(int x, bn_t modular);
+ZP zp_from_int(int x, bn_t modular);
 
-zp zp_add(zp x, zp y);
+ZP zp_add(ZP x, ZP y);
 
-zp zp_neg(zp x);
+ZP zp_neg(ZP x);
 
-zp zp_mul(zp x, zp y);
+ZP zp_mul(ZP x, ZP y);
 
-zp zp_inv(zp x);
+ZP zp_inv(ZP x);
 
-int zp_cmp(zp x, zp y);
+int zp_cmp(ZP x, ZP y);
 
-int zp_cmp_int(zp x, int y);
+int zp_cmp_int(ZP x, int y);
 
 #endif //PPANN_FIELD_H
\ No newline at end of file
diff --git a/include/group.h b/include/group.h
index 6b956f1..bd0a3db 100644
--- a/include/group.h
+++ b/include/group.h
@@ -8,9 +8,9 @@ typedef gt_t gt;
 
 void gen(g x);
 
-void g_mul(g r, g x, zp y);
+void g_mul(g r, g x, ZP y);
 
-void gt_raise(gt r, gt x, zp y);
+void gt_raise(gt r, gt x, ZP y);
 
 void bp_map(g a, g b, gt r);
 
diff --git a/include/helper.h b/include/helper.h
index 2ddca45..0525ee3 100644
--- a/include/helper.h
+++ b/include/helper.h
@@ -4,9 +4,13 @@
 #include <cstdio>
 #include <cstring>
 #include <sys/stat.h>
+#include "hnsw.h"
+#include "ipre.h"
 
 float *fvecs_read(const char *file_path, size_t *d_out, size_t *n_out);
 
 int *ivecs_read(const char *file_path, size_t *d_out, size_t *n_out);
 
+Ct *encrypt_data(int *data, Key key, int d, int n);
+
 #endif //PPANN_HELPER_H
\ No newline at end of file
diff --git a/include/hnsw.h b/include/hnsw.h
index 759e69a..8626601 100644
--- a/include/hnsw.h
+++ b/include/hnsw.h
@@ -1,10 +1,65 @@
 #ifndef PPANN_HNSW_H
 #define PPANN_HNSW_H
 
+#include <set>
+#include <random>
+#include <vector>
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+#include "ipre.h"
 
-class hnsw {
+using namespace std;
 
+
+struct Item {
+    // The ciphertext as value.
+    Ct value;
+
+    // Compute distance between item with something else.
+    int dist(Item &other, Key key, int size, int bound) const;
 };
 
 
+struct HNSWGraph {
+    // Constructor.
+    HNSWGraph(int NN, int MN, int MNZ, int SN, int ML, Key key, int size, int bound);
+
+    /* HNSW related settings. */
+    // Number of neighbors.
+    int NN;
+    // Max number of neighbors in layers >= 1.
+    int MN;
+    // Max number of neighbors in layers 0.
+    int MNZ;
+    // search numbers in construction (efConstruction).
+    int SN;
+    // Max number of layers.
+    int ML;
+    // number of items
+    int numItem;
+    // enter node id
+    int enterNode{};
+    // actual vector of the items
+    vector<Item> items;
+    // adjacent edge lists in each layer
+    vector<unordered_map<int, vector<int>>> layerEdgeLists;
+    // The default generator.
+    default_random_engine generator;
+
+    /* For the IPRE scheme. */
+    Key key;
+    int size;
+    int bound;
+
+    /* Methods. */
+    void insert(Item &q);
+
+    void addEdge(int st, int ed, int lc);
+
+    vector<int> search(Item &q, int K);
+
+    vector<int> searchLayer(Item &q, int ep, int ef, int lc);
+};
+
 #endif //PPANN_HNSW_H
diff --git a/include/ipre.h b/include/ipre.h
index cd84869..10d3b18 100644
--- a/include/ipre.h
+++ b/include/ipre.h
@@ -6,7 +6,7 @@
 
 const int B_SIZE = 6;
 
-struct key {
+struct Key {
     zp_mat A;
     zp_mat B;
     zp_mat Bi;
@@ -15,16 +15,16 @@ struct key {
     bn_t modular;
 };
 
-struct ct {
+struct Ct {
     g_vec ctx;
     g_vec ctk;
     g_vec ctc;
 };
 
-key setup(int size);
+Key setup(int size);
 
-ct enc(key key, const int *message, int size);
+Ct enc(Key key, const int *message, int size);
 
-int eval(key key, ct x, ct y, int size, int bound);
+int eval(Key key, Ct x, Ct y, int size, int bound);
 
 #endif //PPANN_IPRE_H
\ No newline at end of file
diff --git a/include/matrix.h b/include/matrix.h
index f0f7f6d..2d64931 100644
--- a/include/matrix.h
+++ b/include/matrix.h
@@ -4,7 +4,7 @@
 #include "field.h"
 #include "group.h"
 
-typedef zp *zp_mat;
+typedef ZP *zp_mat;
 
 zp_mat matrix_zp_from_int(const int *int_mat, int row, int col, bn_t modular);
 
diff --git a/include/vector.h b/include/vector.h
index 355f935..453a6f8 100644
--- a/include/vector.h
+++ b/include/vector.h
@@ -4,7 +4,7 @@
 #include "field.h"
 #include "group.h"
 
-typedef zp *zp_vec;
+typedef ZP *zp_vec;
 typedef g *g_vec;
 
 zp_vec vector_zp_from_int(const int *int_vec, int size, bn_t modular);
diff --git a/src/field.cpp b/src/field.cpp
index 21b9119..8ef78ff 100644
--- a/src/field.cpp
+++ b/src/field.cpp
@@ -1,75 +1,75 @@
 #include "field.h"
 
-zp rand_zp(bn_st *modular) {
-    zp result;
+ZP rand_zp(bn_st *modular) {
+    ZP result;
     bn_rand_mod(result.point, modular);
     bn_copy(result.modular, modular);
     return result;
 }
 
-zp zp_zero(bn_st *modular) {
-    zp result;
+ZP zp_zero(bn_st *modular) {
+    ZP result;
     bn_set_dig(result.point, 0);
     bn_copy(result.modular, modular);
     return result;
 }
 
-zp zp_one(bn_st *modular) {
-    zp result;
+ZP zp_one(bn_st *modular) {
+    ZP result;
     bn_set_dig(result.point, 1);
     bn_copy(result.modular, modular);
     return result;
 }
 
-zp zp_copy(zp x) {
-    zp result;
+ZP zp_copy(ZP x) {
+    ZP result;
     bn_copy(result.point, x.point);
     bn_copy(result.modular, x.modular);
     return result;
 }
 
-zp zp_from_int(int x, bn_st *modular) {
-    zp result;
+ZP zp_from_int(int x, bn_st *modular) {
+    ZP result;
     bn_set_dig(result.point, x);
     bn_copy(result.modular, modular);
     return result;
 }
 
-zp zp_add(zp x, zp y) {
-    zp result;
+ZP zp_add(ZP x, ZP y) {
+    ZP result;
     bn_add(result.point, x.point, y.point);
     bn_mod(result.point, result.point, x.modular);
     bn_copy(result.modular, x.modular);
     return result;
 }
 
-zp zp_neg(zp x) {
-    zp result;
+ZP zp_neg(ZP x) {
+    ZP result;
     bn_neg(result.point, x.point);
     bn_mod(result.point, result.point, x.modular);
     bn_copy(result.modular, x.modular);
     return result;
 }
 
-zp zp_mul(zp x, zp y) {
-    zp result;
+ZP zp_mul(ZP x, ZP y) {
+    ZP result;
     bn_mul(result.point, x.point, y.point);
     bn_mod(result.point, result.point, x.modular);
     bn_copy(result.modular, x.modular);
     return result;
 }
 
-zp zp_inv(zp x) {
-    zp result;
+ZP zp_inv(ZP x) {
+    ZP result;
     bn_mod_inv(result.point, x.point, x.modular);
     bn_copy(result.modular, x.modular);
     return result;
 }
 
-int zp_cmp(zp x, zp y) {
+int zp_cmp(ZP x, ZP y) {
     return bn_cmp(x.point, y.point) == RLC_EQ;
 }
 
-int zp_cmp_int(zp x, int y) {
+int zp_cmp_int(ZP x, int y) {
     return bn_cmp_dig(x.point, y) == RLC_EQ;
 }
\ No newline at end of file
diff --git a/src/group.cpp b/src/group.cpp
index 6558d99..e358225 100644
--- a/src/group.cpp
+++ b/src/group.cpp
@@ -4,11 +4,11 @@ void gen(ep_st *x) {
     g1_get_gen(x);
 }
 
-void g_mul(ep_st *r, ep_st *x, zp y) {
+void g_mul(ep_st *r, ep_st *x, ZP y) {
     g1_mul(r, x, y.point);
 }
 
-void gt_raise(fp_t *r, fp_t *x, zp y) {
+void gt_raise(fp_t *r, fp_t *x, ZP y) {
     gt_exp(r, x, y.point);
 }
 
diff --git a/src/helper.cpp b/src/helper.cpp
index ad4f426..e9fdb5d 100644
--- a/src/helper.cpp
+++ b/src/helper.cpp
@@ -37,4 +37,12 @@ float *fvecs_read(const char *file_path, size_t *d_out, size_t *n_out) {
 int *ivecs_read(const char *file_path, size_t *d_out, size_t *n_out) {
     // Cast the float results to integers.
     return (int *) fvecs_read(file_path, d_out, n_out);
-}
\ No newline at end of file
+}
+
+Ct *encrypt_data(int *data, Key key, int d, int n) {
+    auto *encrypted_data = new Ct[n];
+
+    for (int i = 0; i < n; i++) encrypted_data[i] = enc(key, &data[i * d], d);
+
+    return encrypted_data;
+}
diff --git a/src/hnsw.cpp b/src/hnsw.cpp
index 0020eb0..cbfe71c 100644
--- a/src/hnsw.cpp
+++ b/src/hnsw.cpp
@@ -1 +1,102 @@
 #include "hnsw.h"
+
+
+int Item::dist(Item &other, Key key, int size, int bound) const {
+    return eval(key, value, other.value, size, bound);
+}
+
+HNSWGraph::HNSWGraph(int NN, int MN, int MNZ, int SN, int ML, Key key, int size, int bound) :
+        NN(NN), MN(MN), MNZ(MNZ), SN(SN), ML(ML), key(key), size(size), bound(bound) {
+    numItem = 0;
+    enterNode = 0;
+    layerEdgeLists.emplace_back();
+}
+
+void HNSWGraph::insert(Item &q) {
+    int nid = items.size();
+    numItem++;
+    items.push_back(q);
+    // sample layer
+    int maxLyer = layerEdgeLists.size() - 1;
+    int l = 0;
+    uniform_real_distribution<double> distribution(0.0, 1.0);
+    while (l < ML && (1.0 / ML <= distribution(generator))) {
+        l++;
+        if (layerEdgeLists.size() <= l) layerEdgeLists.emplace_back();
+    }
+    if (nid == 0) {
+        enterNode = nid;
+        return;
+    }
+    // search up layer entrance
+    int ep = enterNode;
+    for (int i = maxLyer; i > l; i--) ep = searchLayer(q, ep, 1, i)[0];
+    for (int i = min(l, maxLyer); i >= 0; i--) {
+        int MM = l == 0 ? MNZ : MN;
+        vector<int> neighbors = searchLayer(q, ep, SN, i);
+        vector<int> selectedNeighbors = vector<int>(neighbors.begin(),
+                                                    neighbors.begin() + min(int(neighbors.size()), NN));
+        for (int n: selectedNeighbors) addEdge(n, nid, i);
+        for (int n: selectedNeighbors) {
+            if (layerEdgeLists[i][n].size() > MM) {
+                vector<pair<double, int>> distPairs;
+                for (int nn: layerEdgeLists[i][n])
+                    distPairs.emplace_back(items[n].dist(items[nn], key, size, bound), nn);
+                sort(distPairs.begin(), distPairs.end());
+                layerEdgeLists[i][n].clear();
+                for (int d = 0; d < min(int(distPairs.size()), MM); d++)
+                    layerEdgeLists[i][n].push_back(distPairs[d].second);
+            }
+        }
+        ep = selectedNeighbors[0];
+    }
+    if (l == layerEdgeLists.size() - 1) enterNode = nid;
+}
+
+void HNSWGraph::addEdge(int st, int ed, int lc) {
+    if (st == ed) return;
+    layerEdgeLists[lc][st].push_back(ed);
+    layerEdgeLists[lc][ed].push_back(st);
+}
+
+vector<int> HNSWGraph::search(Item &q, int K) {
+    auto maxLyer = layerEdgeLists.size() - 1;
+    int ep = enterNode;
+    for (auto l = maxLyer; l >= 1; l--) ep = searchLayer(q, ep, 1, l)[0];
+    return searchLayer(q, ep, K, 0);
+}
+
+vector<int> HNSWGraph::searchLayer(Item &q, int ep, int ef, int lc) {
+    set<pair<double, int>> candidates;
+    set<pair<double, int>> nearestNeighbors;
+    unordered_set<int> isVisited;
+
+    double td = q.dist(items[ep], key, size, bound);
+    candidates.insert(make_pair(td, ep));
+    nearestNeighbors.insert(make_pair(td, ep));
+    isVisited.insert(ep);
+    while (!candidates.empty()) {
+        auto ci = candidates.begin();
+        candidates.erase(candidates.begin());
+        int nid = ci->second;
+        auto fi = nearestNeighbors.end();
+        fi--;
+        if (ci->first > fi->first) break;
+        for (int ed: layerEdgeLists[lc][nid]) {
+            if (isVisited.find(ed) != isVisited.end()) continue;
+            fi = nearestNeighbors.end();
+            fi--;
+            isVisited.insert(ed);
+            td = q.dist(items[ed], key, size, bound);
+            if ((td < fi->first) || nearestNeighbors.size() < ef) {
+                candidates.insert(make_pair(td, ed));
+                nearestNeighbors.insert(make_pair(td, ed));
+                if (nearestNeighbors.size() > ef) nearestNeighbors.erase(fi);
+            }
+        }
+    }
+    vector<int> results;
+    results.reserve(nearestNeighbors.size());
+    for (auto &p: nearestNeighbors) results.push_back(p.second);
+    return results;
+}
diff --git a/src/ipre.cpp b/src/ipre.cpp
index 47e79ba..fb01a06 100644
--- a/src/ipre.cpp
+++ b/src/ipre.cpp
@@ -1,7 +1,7 @@
 #include "ipre.h"
 
-key setup(int size) {
-    key key{};
+Key setup(int size) {
+    Key key{};
     pc_get_ord(key.modular);
     gen(key.base);
     bp_map(key.base, key.base, key.t_base);
@@ -11,9 +11,9 @@ key setup(int size) {
     return key;
 }
 
-ct enc(key key, const int *message, int size) {
-    // Declare the returned ciphertext and convert message to zp.
-    ct ct{};
+Ct enc(Key key, const int *message, int size) {
+    // Declare the returned ciphertext and convert message to ZP.
+    Ct ct{};
     zp_vec x = vector_zp_from_int(message, size, key.modular);
 
     // Helper values.
@@ -27,7 +27,7 @@ ct enc(key key, const int *message, int size) {
     zp_vec sAx = vector_add(sA, x, size);
     ct.ctx = vector_raise(key.base, sAx, size);
 
-    // We compute the function hiding inner product encryption key.
+    // We compute the function hiding inner product encryption Key.
     zp_mat AT = matrix_transpose(key.A, 2, size);
     zp_vec xAT = matrix_multiply(x, AT, 1, size, 2, key.modular);
     zp_vec xATs = vector_merge(xAT, s, 2, 2);
@@ -48,7 +48,7 @@ ct enc(key key, const int *message, int size) {
     return ct;
 }
 
-int eval(key key, ct x, ct y, int size, int bound) {
+int eval(Key key, Ct x, Ct y, int size, int bound) {
     // Decrypt components.
     gt xy, ct;
     inner_product(xy, x.ctx, y.ctx, size);
diff --git a/src/matrix.cpp b/src/matrix.cpp
index 27836f1..bf66e09 100644
--- a/src/matrix.cpp
+++ b/src/matrix.cpp
@@ -2,7 +2,7 @@
 
 zp_mat matrix_zp_from_int(const int *int_mat, int row, int col, bn_st *modular) {
     zp_mat x;
-    x = (zp_mat) malloc(sizeof(zp) * row * col);
+    x = (zp_mat) malloc(sizeof(ZP) * row * col);
     for (int i = 0; i < row; i++) {
         for (int j = 0; j < col; j++) {
             x[i * col + j] = zp_from_int(int_mat[i * col + j], modular);
@@ -13,7 +13,7 @@ zp_mat matrix_zp_from_int(const int *int_mat, int row, int col, bn_st *modular)
 
 zp_mat matrix_zp_rand(int row, int col, bn_st *modular) {
     zp_mat x;
-    x = (zp_mat) malloc(sizeof(zp) * row * col);
+    x = (zp_mat) malloc(sizeof(ZP) * row * col);
     for (int i = 0; i < row; i++) {
         for (int j = 0; j < col; j++) {
             x[i * col + j] = rand_zp(modular);
@@ -24,7 +24,7 @@ zp_mat matrix_zp_rand(int row, int col, bn_st *modular) {
 
 zp_mat matrix_identity(int size, bn_st *modular) {
     zp_mat x;
-    x = (zp_mat) malloc(sizeof(zp) * size * size);
+    x = (zp_mat) malloc(sizeof(ZP) * size * size);
     for (int i = 0; i < size; i++) {
         for (int j = 0; j < size; j++) {
             if (i == j) x[i * size + j] = zp_one(modular);
@@ -46,7 +46,7 @@ int matrix_is_identity(zp_mat x, int size) {
 
 zp_mat matrix_transpose(zp_mat x, int row, int col) {
     zp_mat xt;
-    xt = (zp_mat) malloc(sizeof(zp) * row * col);
+    xt = (zp_mat) malloc(sizeof(ZP) * row * col);
     for (int i = 0; i < row; i++) {
         for (int j = 0; j < col; j++) {
             xt[j * row + i] = zp_copy(x[i * col + j]);
@@ -57,7 +57,7 @@ zp_mat matrix_transpose(zp_mat x, int row, int col) {
 
 zp_mat matrix_merge(zp_mat x, zp_mat y, int row, int col_x, int col_y) {
     zp_mat xy;
-    xy = (zp_mat) malloc(sizeof(zp) * row * (col_x + col_y));
+    xy = (zp_mat) malloc(sizeof(ZP) * row * (col_x + col_y));
     for (int i = 0; i < row; i++) {
         for (int j = 0; j < col_x; j++) {
             xy[i * (col_x + col_y) + j] = zp_copy(x[i * col_x + j]);
@@ -70,7 +70,7 @@ zp_mat matrix_merge(zp_mat x, zp_mat y, int row, int col_x, int col_y) {
 }
 
 zp_mat matrix_multiply(zp_mat x, zp_mat y, int row_x, int row_y, int col_y, bn_st *modular) {
-    auto xy = (zp_mat) malloc(sizeof(zp) * row_x * col_y);
+    auto xy = (zp_mat) malloc(sizeof(ZP) * row_x * col_y);
 
     for (int i = 0; i < row_x; i++) {
         for (int j = 0; j < col_y; j++) {
@@ -89,8 +89,8 @@ zp_mat matrix_inverse(zp_mat x, int size, bn_st *modular) {
     zp_mat row_echelon = matrix_merge(x, identity, size, size, size);
 
     // Declare temp value.
-    zp temp_multiplier;
-    zp temp_neg;
+    ZP temp_multiplier;
+    ZP temp_neg;
 
     // Bottom left half to all zeros.
     for (int i = 0; i < size; i++) {
@@ -129,7 +129,7 @@ zp_mat matrix_inverse(zp_mat x, int size, bn_st *modular) {
 
     // Copy over the output.
     zp_mat xi;
-    xi = (zp_mat) malloc(sizeof(zp) * size * size);
+    xi = (zp_mat) malloc(sizeof(ZP) * size * size);
     for (int i = 0; i < size; i++) {
         for (int j = 0; j < size; j++) {
             xi[i * size + j] = zp_copy(row_echelon[i * 2 * size + size + j]);
diff --git a/src/vector.cpp b/src/vector.cpp
index 33de448..e34bddb 100644
--- a/src/vector.cpp
+++ b/src/vector.cpp
@@ -2,21 +2,21 @@
 
 zp_vec vector_zp_from_int(const int *int_vec, int size, bn_st *modular) {
     zp_vec x;
-    x = (zp_vec) malloc(sizeof(zp) * size);
+    x = (zp_vec) malloc(sizeof(ZP) * size);
     for (int i = 0; i < size; i++) x[i] = zp_from_int(int_vec[i], modular);
     return x;
 }
 
 zp_vec vector_zp_rand(int size, bn_st *modular) {
     zp_vec x;
-    x = (zp_vec) malloc(sizeof(zp) * size);
+    x = (zp_vec) malloc(sizeof(ZP) * size);
     for (int i = 0; i < size; i++) x[i] = rand_zp(modular);
     return x;
 }
 
 zp_vec vector_merge(zp_vec a, zp_vec b, int size_a, int size_b) {
     zp_vec r;
-    r = (zp_vec) malloc(sizeof(zp) * (size_a + size_b));
+    r = (zp_vec) malloc(sizeof(ZP) * (size_a + size_b));
     for (int i = 0; i < size_a; i++) r[i] = zp_copy(a[i]);
     for (int i = 0; i < size_b; i++) r[i + size_a] = zp_copy(b[i]);
     return r;
@@ -24,7 +24,7 @@ zp_vec vector_merge(zp_vec a, zp_vec b, int size_a, int size_b) {
 
 zp_vec vector_add(zp_vec a, zp_vec b, int size) {
     zp_vec r;
-    r = (zp_vec) malloc(sizeof(zp) * size);
+    r = (zp_vec) malloc(sizeof(ZP) * size);
     for (int i = 0; i < size; i++) r[i] = zp_add(a[i], b[i]);
     return r;
 }
diff --git a/tests/test_field.cpp b/tests/test_field.cpp
index f5f4e3e..c86c2c1 100644
--- a/tests/test_field.cpp
+++ b/tests/test_field.cpp
@@ -1,51 +1,51 @@
 #include "field.h"
 
 int test_zp_zero(bn_st *N) {
-    zp x = zp_zero(N);
+    ZP x = zp_zero(N);
     return zp_cmp_int(x, 0);
 }
 
 int test_zp_one(bn_st *N) {
-    zp x = zp_one(N);
+    ZP x = zp_one(N);
     return zp_cmp_int(x, 1);
 }
 
 int test_zp_copy(bn_st *N) {
-    zp x = zp_from_int(10, N);
-    zp y = zp_copy(x);
+    ZP x = zp_from_int(10, N);
+    ZP y = zp_copy(x);
     return zp_cmp(x, y);
 }
 
 int test_zp_from_int(bn_st *N) {
-    zp x = zp_from_int(3, N);
+    ZP x = zp_from_int(3, N);
     return zp_cmp_int(x, 3);
 }
 
 int test_zp_add(bn_st *N) {
-    zp x = zp_from_int(10, N);
-    zp y = zp_from_int(20, N);
-    zp z = zp_add(x, y);
+    ZP x = zp_from_int(10, N);
+    ZP y = zp_from_int(20, N);
+    ZP z = zp_add(x, y);
     return zp_cmp_int(z, 30);
 }
 
 int test_zp_neg(bn_st *N) {
-    zp x = rand_zp(N);
-    zp y = zp_neg(x);
-    zp z = zp_add(x, y);
+    ZP x = rand_zp(N);
+    ZP y = zp_neg(x);
+    ZP z = zp_add(x, y);
     return zp_cmp_int(z, 0);
 }
 
 int test_zp_mul(bn_st *N) {
-    zp x = zp_from_int(10, N);
-    zp y = zp_from_int(20, N);
-    zp z = zp_mul(x, y);
+    ZP x = zp_from_int(10, N);
+    ZP y = zp_from_int(20, N);
+    ZP z = zp_mul(x, y);
     return zp_cmp_int(z, 200);
 }
 
 int test_zp_inv(bn_st *N) {
-    zp x = rand_zp(N);
-    zp y = zp_inv(x);
-    zp z = zp_mul(x, y);
+    ZP x = rand_zp(N);
+    ZP y = zp_inv(x);
+    ZP z = zp_mul(x, y);
     return zp_cmp_int(z, 1);
 }
 
diff --git a/tests/test_group.cpp b/tests/test_group.cpp
index 3905e51..57786da 100644
--- a/tests/test_group.cpp
+++ b/tests/test_group.cpp
@@ -8,8 +8,8 @@ int test_generator() {
 
 int test_all(bn_st *N) {
     // Set integers.
-    zp m = zp_from_int(5, N);
-    zp n = zp_from_int(25, N);
+    ZP m = zp_from_int(5, N);
+    ZP n = zp_from_int(25, N);
 
     // Declare variables.
     g a, b;
diff --git a/tests/test_helper.cpp b/tests/test_helper.cpp
index 0162cc4..74d5205 100644
--- a/tests/test_helper.cpp
+++ b/tests/test_helper.cpp
@@ -1,5 +1,8 @@
 #include "helper.h"
 
+#include "iostream"
+using namespace std;
+
 int test_read_fvecs() {
     // Set dimensions holders and get the data.
     size_t d, n;
@@ -30,8 +33,30 @@ int test_read_ivecs() {
     return 1;
 }
 
+int test_encrypt() {
+    // Set dimensions holders and get the data.
+    size_t d, n;
+    float *xd = fvecs_read("sift_query.fvecs", &d, &n);
+
+    cout << static_cast<int>(xd[1]) << endl;
+    //
+    int d_int, n_int;
+    d_int = static_cast<int>(d);
+    n_int = static_cast<int>(n);
+
+    //
+//    Key key = setup(d_int);
+//    Ct *encrypted_data = encrypt_data(gt, key, d_int, 1);
+    return 1;
+}
+
 int main() {
-    if (test_read_fvecs() != 1) return 1;
-    if (test_read_ivecs() != 1) return 1;
+    // Init core and setup.
+    core_init();
+    pc_param_set_any();
+
+//    if (test_read_fvecs() != 1) return 1;
+//    if (test_read_ivecs() != 1) return 1;
+    if (test_encrypt() != 1) return 1;
     return 0;
 }
\ No newline at end of file
diff --git a/tests/test_ipre.cpp b/tests/test_ipre.cpp
index d8d1ea8..6f90891 100644
--- a/tests/test_ipre.cpp
+++ b/tests/test_ipre.cpp
@@ -5,10 +5,10 @@ int test_scheme() {
     int x[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
     int y[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 100};
     // Initialize the scheme.
-    key key = setup(10);
+    Key key = setup(10);
     // Encrypt the messages.
-    ct ct_x = enc(key, x, 10);
-    ct ct_y = enc(key, y, 10);
+    Ct ct_x = enc(key, x, 10);
+    Ct ct_y = enc(key, y, 10);
     // Evaluate the two ciphertexts.
     int output = eval(key, ct_x, ct_y, 10, 150);
 
-- 
GitLab