From 7e6fcb84ffff7e09f6fad28e4beebcff49adb2e2 Mon Sep 17 00:00:00 2001
From: Weiqi <weltch1997@gmail.com>
Date: Thu, 27 Apr 2023 05:18:59 -0400
Subject: [PATCH] Done with the hnsw

---
 include/helper.h      |  2 +-
 src/helper.cpp        |  6 ++---
 src/hnsw.cpp          | 26 ++++++++++----------
 tests/CMakeLists.txt  |  5 +++-
 tests/test_helper.cpp |  4 +--
 tests/test_hnsw.cpp   | 57 +++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 80 insertions(+), 20 deletions(-)
 create mode 100644 tests/test_hnsw.cpp

diff --git a/include/helper.h b/include/helper.h
index 6d76777..ae2291c 100644
--- a/include/helper.h
+++ b/include/helper.h
@@ -13,6 +13,6 @@ int *ivecs_read(const char *file_path, size_t *d_out, size_t *n_out);
 
 int *float_to_int(const float *data, size_t size);
 
-Ct *encrypt_data(const int *data, Key key, size_t d, size_t n);
+Item *encrypt_data(const int *data, Key key, size_t d, size_t n);
 
 #endif //PPANN_HELPER_H
\ No newline at end of file
diff --git a/src/helper.cpp b/src/helper.cpp
index 3723185..46b3271 100644
--- a/src/helper.cpp
+++ b/src/helper.cpp
@@ -50,12 +50,12 @@ int *float_to_int(const float *data, size_t size) {
     return int_data;
 }
 
-Ct *encrypt_data(const int *data, Key key, size_t d, size_t n) {
+Item *encrypt_data(const int *data, Key key, size_t d, size_t n) {
     // Get a new list for encrypted data.
-    auto *encrypted_data = new Ct[n];
+    auto *encrypted_data = new Item[n];
 
     // Encrypt each vector.
-    for (int i = 0; i < n; i++) encrypted_data[i] = enc(key, &data[i * d], static_cast<int>(d));
+    for (int i = 0; i < n; i++) encrypted_data[i].value = enc(key, &data[i * d], static_cast<int>(d));
 
     // Return pointer of the list.
     return encrypted_data;
diff --git a/src/hnsw.cpp b/src/hnsw.cpp
index cbfe71c..60c54da 100644
--- a/src/hnsw.cpp
+++ b/src/hnsw.cpp
@@ -1,6 +1,5 @@
 #include "hnsw.h"
 
-
 int Item::dist(Item &other, Key key, int size, int bound) const {
     return eval(key, value, other.value, size, bound);
 }
@@ -13,11 +12,11 @@ HNSWGraph::HNSWGraph(int NN, int MN, int MNZ, int SN, int ML, Key key, int size,
 }
 
 void HNSWGraph::insert(Item &q) {
-    int nid = items.size();
+    int nid = static_cast<int>(items.size());
     numItem++;
     items.push_back(q);
     // sample layer
-    int maxLyer = layerEdgeLists.size() - 1;
+    int maxLayer = static_cast<int>(layerEdgeLists.size()) - 1;
     int l = 0;
     uniform_real_distribution<double> distribution(0.0, 1.0);
     while (l < ML && (1.0 / ML <= distribution(generator))) {
@@ -30,8 +29,8 @@ void HNSWGraph::insert(Item &q) {
     }
     // search up layer entrance
     int ep = enterNode;
-    for (int i = maxLyer; i > l; i--) ep = searchLayer(q, ep, 1, i)[0];
-    for (int i = min(l, maxLyer); i >= 0; i--) {
+    for (int i = maxLayer; i > l; i--) ep = searchLayer(q, ep, 1, i)[0];
+    for (int i = min(l, maxLayer); i >= 0; i--) {
         int MM = l == 0 ? MNZ : MN;
         vector<int> neighbors = searchLayer(q, ep, SN, i);
         vector<int> selectedNeighbors = vector<int>(neighbors.begin(),
@@ -53,23 +52,24 @@ void HNSWGraph::insert(Item &q) {
     if (l == layerEdgeLists.size() - 1) enterNode = nid;
 }
 
+vector<int> HNSWGraph::search(Item &q, int K) {
+    int maxLayer = static_cast<int>(layerEdgeLists.size()) - 1;
+    int ep = enterNode;
+    for (auto l = maxLayer; l >= 1; l--) ep = searchLayer(q, ep, 1, l)[0];
+    return searchLayer(q, ep, K, 0);
+}
+
 void HNSWGraph::addEdge(int st, int ed, int lc) {
     if (st == ed) return;
     layerEdgeLists[lc][st].push_back(ed);
     layerEdgeLists[lc][ed].push_back(st);
 }
 
-vector<int> HNSWGraph::search(Item &q, int K) {
-    auto maxLyer = layerEdgeLists.size() - 1;
-    int ep = enterNode;
-    for (auto l = maxLyer; l >= 1; l--) ep = searchLayer(q, ep, 1, l)[0];
-    return searchLayer(q, ep, K, 0);
-}
-
 vector<int> HNSWGraph::searchLayer(Item &q, int ep, int ef, int lc) {
+    unordered_set<int> isVisited;
     set<pair<double, int>> candidates;
     set<pair<double, int>> nearestNeighbors;
-    unordered_set<int> isVisited;
+
 
     double td = q.dist(items[ep], key, size, bound);
     candidates.insert(make_pair(td, ep));
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index fb3fe9e..a5e1b85 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -5,6 +5,7 @@ add_executable(test_vector test_vector.cpp)
 add_executable(test_matrix test_matrix.cpp)
 add_executable(test_ipre test_ipre.cpp)
 add_executable(test_helper test_helper.cpp)
+add_executable(test_hnsw test_hnsw.cpp)
 
 # Link tests to the main library.
 target_link_libraries(test_field PRIVATE ppann_lib)
@@ -13,6 +14,7 @@ target_link_libraries(test_vector PRIVATE ppann_lib)
 target_link_libraries(test_matrix PRIVATE ppann_lib)
 target_link_libraries(test_ipre PRIVATE ppann_lib)
 target_link_libraries(test_helper PRIVATE ppann_lib)
+target_link_libraries(test_hnsw PRIVATE ppann_lib)
 
 # Register the previous tests.
 add_test(NAME test_field COMMAND test_field)
@@ -20,4 +22,5 @@ add_test(NAME test_group COMMAND test_group)
 add_test(NAME test_vector COMMAND test_vector)
 add_test(NAME test_matrix COMMAND test_matrix)
 add_test(NAME test_ipre COMMAND test_ipre)
-add_test(NAME test_helper COMMAND test_helper WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/data")
\ No newline at end of file
+add_test(NAME test_helper COMMAND test_helper WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/data")
+add_test(NAME test_hnsw COMMAND test_hnsw WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/data")
\ No newline at end of file
diff --git a/tests/test_helper.cpp b/tests/test_helper.cpp
index 1ea3e23..8efa972 100644
--- a/tests/test_helper.cpp
+++ b/tests/test_helper.cpp
@@ -59,10 +59,10 @@ int test_encrypt() {
 
     // Encrypt the first two vectors.
     Key key = setup(d_int);
-    Ct *encrypted_data = encrypt_data(data, key, d, 2);
+    Item *encrypted_data = encrypt_data(data, key, d, 2);
 
     // Get inner product of the first two vectors.
-    int result = eval(key, encrypted_data[0], encrypted_data[1], d_int, 200000);
+    int result = eval(key, encrypted_data[0].value, encrypted_data[1].value, d_int, 200000);
 
     // Check for whether the data is correct.
     if (result != 184094) return 0;
diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp
new file mode 100644
index 0000000..29e196c
--- /dev/null
+++ b/tests/test_hnsw.cpp
@@ -0,0 +1,57 @@
+#include "ipre.h"
+#include "hnsw.h"
+#include "helper.h"
+
+int test_constructor() {
+    // Get the key.
+    Key key = setup(10);
+
+    // Get the HNSW object.
+    HNSWGraph index(10, 30, 10, 40, 4, key, 10, 100);
+
+    // Check for whether the data is correct.
+    if (index.NN != 10) return 0;
+
+    // If everything passes, return 1.
+    return 1;
+}
+
+int test_insert() {
+    // Set dimensions holders and get the data.
+    size_t d, n;
+    float *xd = fvecs_read("sift_query.fvecs", &d, &n);
+
+    // Conversion.
+    int *data = float_to_int(xd, d * n);
+
+    // Cast d to integer.
+    int d_int = static_cast<int>(d);
+
+    // Encrypt the first three vectors.
+    Key key = setup(d_int);
+    Item *encrypted_data = encrypt_data(data, key, d, 3);
+
+    // Get the HNSW object.
+    HNSWGraph index(10, 30, 10, 40, 4, key, 10, 100);
+
+    // Do the insert.
+    for (int i = 0; i < 3; i++) index.insert(encrypted_data[i]);
+
+    // Check for whether the data is correct.
+    if (index.numItem != 3) return 0;
+
+    return 1;
+}
+
+
+int main() {
+    // Init core and setup.
+    core_init();
+    pc_param_set_any();
+
+    // Perform tests.
+    if (test_constructor() != 1) return 1;
+    if (test_insert() != 1) return 1;
+
+    return 0;
+}
\ No newline at end of file
-- 
GitLab