summaryrefslogtreecommitdiffstats
path: root/02-data_structures/04-binary_search_trees/02-set_range_sum
diff options
context:
space:
mode:
Diffstat (limited to '02-data_structures/04-binary_search_trees/02-set_range_sum')
-rw-r--r--02-data_structures/04-binary_search_trees/02-set_range_sum/set_range_sum.cpp47
1 files changed, 38 insertions, 9 deletions
diff --git a/02-data_structures/04-binary_search_trees/02-set_range_sum/set_range_sum.cpp b/02-data_structures/04-binary_search_trees/02-set_range_sum/set_range_sum.cpp
index 8de8e72..c634094 100644
--- a/02-data_structures/04-binary_search_trees/02-set_range_sum/set_range_sum.cpp
+++ b/02-data_structures/04-binary_search_trees/02-set_range_sum/set_range_sum.cpp
@@ -2,6 +2,8 @@
// Splay tree implementation
+bool debug = false;
+
// Vertex of a splay tree
struct Vertex {
int key;
@@ -16,6 +18,16 @@ struct Vertex {
: key(key), sum(sum), left(left), right(right), parent(parent) {}
};
+
+void print(Vertex *v, int l) {
+ if (v == NULL) return;
+ for (int i = 0; i < l; i++)
+ printf(" ");
+ printf("%p [%d] ->%p -> %p\n", v, v->key, v->left, v->right);
+ if (v->left != NULL) print(v->left, l+1);
+ if (v->right != NULL) print(v->right, l+1);
+}
+
void update(Vertex* v) {
if (v == NULL) return;
v->sum = v->key + (v->left != NULL ? v->left->sum : 0ll) + (v->right != NULL ? v->right->sum : 0ll);
@@ -147,6 +159,7 @@ Vertex* merge(Vertex* left, Vertex* right) {
Vertex* root = NULL;
void insert(int x) {
+ if (debug) printf("insert %d\n", x);
Vertex* left = NULL;
Vertex* right = NULL;
Vertex* new_vertex = NULL;
@@ -158,26 +171,41 @@ void insert(int x) {
}
void erase(int x) {
- // Implement erase yourself
-
+ if (debug) printf("erase %d\n", x);
+ Vertex *v = find(root, x);
+ if (v == NULL)
+ return;
+ splay(root, v);
+ if (v->left != NULL)
+ v->left->parent = NULL;
+ if (v->right != NULL)
+ v->right->parent = NULL;
+ root = merge(v->left, v->right);
}
-bool find(int x) {
- // Implement find yourself
-
- return false;
+bool find(int x) {
+ if (debug) printf("find %d\n", x);
+ Vertex *v = find(root, x);
+ if (v == NULL) return false;
+ if (v != NULL && root != v)
+ splay(root, v);
+ return (v->key == x);
}
long long sum(int from, int to) {
+ if (debug) printf("sum %d - %d\n", from, to);
Vertex* left = NULL;
Vertex* middle = NULL;
Vertex* right = NULL;
split(root, from, left, middle);
split(middle, to + 1, middle, right);
long long ans = 0;
- // Complete the implementation of sum
-
- return ans;
+ if (middle != NULL) {
+ update(middle);
+ ans = middle->sum;
+ }
+ root = merge(merge(left, middle), right);
+ return ans;
}
const int MODULO = 1000000001;
@@ -214,6 +242,7 @@ int main(){
last_sum_result = int(res % MODULO);
}
}
+ if (debug) print(root, 4);
}
return 0;
}