Segment Tree
A Segment Tree is used when there are many range queries on an array interleaved with modifications of elements of that same array. Examples are finding the sum of all the elements of an array from index l to index r, or finding the minimum of all the elements from index l to index r (famously known as the Range Minimum Query problem). These problems can be solved easily with one of the most versatile data structures, the Segment Tree.
Notes:
Complexity of build() is O(N).
Update: (point)
Complexity of update will be O(logN).
Query:
Complexity of query will be O(logN).
Still have doubts?
Reference Links:
Let's jump to some questions:
Question 1: Link to Question (Update in a single position)
Notes:
- The root of the segment tree contains the whole array A[0: N-1]
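- There are N leaves representing the N elements of the array.
- The number of internal nodes is N-1.
- The total number of nodes is 2*N-1. With the 1-based indexing used below (the children of node k are 2k and 2k+1), the indices are not contiguous, so the tree array should have 4*N entries.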
Operations(Templates):-
Build:
// Build the sum tree for A[start..end] into tree[], rooted at `node`.
// The children of node k are stored at 2k and 2k+1.
void build(int node, int start, int end)
{
    if (start != end)
    {
        const int mid = (start + end) / 2;
        const int lc = 2 * node;
        const int rc = lc + 1;
        build(lc, start, mid);       // left half  [start, mid]
        build(rc, mid + 1, end);     // right half [mid+1, end]
        // An internal node stores the sum of its two children.
        tree[node] = tree[lc] + tree[rc];
        return;
    }
    // A leaf holds exactly one array element.
    tree[node] = A[start];
}
Complexity of build() is O(N).
Update: (point)
// Point update: add `val` to A[idx] and to every tree node whose segment
// [start, end] contains idx, walking down from `node`.
void update(int node, int start, int end, int idx, int val)
{
    if (start == end)
    {
        // Reached the leaf for idx.
        A[idx] += val;
        tree[node] += val;
        return;
    }
    const int mid = (start + end) / 2;
    const bool goLeft = (start <= idx && idx <= mid);
    if (goLeft)
        update(2 * node, start, mid, idx, val);
    else
        update(2 * node + 1, mid + 1, end, idx, val);
    // Recompute this node from its (possibly changed) children.
    tree[node] = tree[2 * node] + tree[2 * node + 1];
}
Complexity of update will be O(logN).
Note: For range updates we generally prefer lazy propagation, which we will discuss in the next post. Without lazy propagation, a range update can be written as follows:
void update(int node, int st, int end, int qs, int qe, int val) { if(st > end || st > qe || end < qs) // Current segment is not within range [i, j] return; if(st == en) { // Leaf node tree[node] += val; return; } int mid=(st+end)/2; update_tree(node*2, st, mid, qs , qe, value); // Updating left child update_tree(1+node*2, mid+1, end, qs, qe, value); // Updating right child tree[node] = max(tree[node*2], tree[node*2+1]); // Updating root with max value }
The complexity of this update is O(N), because every leaf inside the range is visited.
// Return the sum of A[l..r] restricted to this node's segment [start, end].
int query(int node, int start, int end, int l, int r)
{
    const bool disjoint = (r < start || end < l);
    if (disjoint)
        return 0;            // no overlap: contributes nothing to the sum
    const bool covered = (l <= start && end <= r);
    if (covered)
        return tree[node];   // fully inside: use the precomputed sum
    // Partial overlap: split the segment and add both halves.
    const int mid = (start + end) / 2;
    const int leftSum = query(2 * node, start, mid, l, r);
    const int rightSum = query(2 * node + 1, mid + 1, end, l, r);
    return leftSum + rightSum;
}
Complexity of query will be O(logN).
Still have doubts?
Reference Links:
Let's jump to some questions:
Question 1: Link to Question (Update in a single position)
1. 0 x y : Update the xth element of array to y.
2. 1 v: Find the position of first element which is greater than or equal to v, and if there is no element greater than or equal to v then answer is -1.
Input:
5 4
5 4 3 2 1
1 4
1 6
0 3 7
1 6
Output:
1
-1
3
Solution:
#include <bits/stdc++.h>
using namespace std;

#define N 100005

// a[] is the 0-based input array; seg[] is a 1-based max segment tree over it.
// A recursive segment tree uses indices up to ~4*n, so the tree needs 4*N slots
// (2*N overflows when n is not a power of two).
int a[N], seg[4 * N];

// Build the max tree for a[st..end] rooted at node idx.
void build(int idx, int st, int end)
{
    if (st == end) {
        seg[idx] = a[st];
        return;
    }
    int mid = (st + end) / 2;
    build(2 * idx, st, mid);
    build(2 * idx + 1, mid + 1, end);
    seg[idx] = max(seg[2 * idx], seg[2 * idx + 1]);
}

// Set a[pos] = val and refresh the maxima along the root-to-leaf path.
void update(int idx, int st, int end, int pos, int val)
{
    if (st == end) {
        a[pos] = val;
        seg[idx] = val;
        return;
    }
    int mid = (st + end) / 2;
    if (st <= pos && pos <= mid)
        update(2 * idx, st, mid, pos, val);
    else
        update(2 * idx + 1, mid + 1, end, pos, val);
    seg[idx] = max(seg[2 * idx], seg[2 * idx + 1]);
}

// Return the 1-based position of the first element >= val within a[st..end],
// or -1 if there is none.
int query(int idx, int st, int end, int val)
{
    // The segment's maximum is below val, so nothing here qualifies. Checking the
    // node itself (not only its children) also handles a single-element array.
    if (seg[idx] < val)
        return -1;
    if (st == end)
        return st + 1;
    int mid = (st + end) / 2;
    // Prefer the left child so that the leftmost qualifying element is found.
    if (seg[2 * idx] >= val)
        return query(2 * idx, st, mid, val);
    return query(2 * idx + 1, mid + 1, end, val);
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, q;
    cin >> n >> q;
    for (int i = 0; i < n; i++)
        cin >> a[i];
    build(1, 0, n - 1);
    while (q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int v;
            cin >> v;
            cout << query(1, 0, n - 1, v) << '\n';
        } else {
            int x, y;
            cin >> x >> y;
            update(1, 0, n - 1, x - 1, y);  // x is 1-based
        }
    }
}
Question 2: Link to Question
Query 0: modify the element at index i to x. Query 1: count the even numbers in the range l to r (inclusive). Query 2: count the odd numbers in the range l to r (inclusive).
Input: Output:
6 2 1 2 3 4 5 6 2 4 4 1 2 5 2 1 4 0 5 4 1 1 6Solution:#include <bits/stdc++.h> using namespace std; #define N 100005 int seg[3*N],lazy[3*N],a[N]; void build(int idx,int st,int end){ if(st==end){ seg[idx]=a[st]; } else{ int mid=(st+end)/2; build(idx*2,st,mid); build(idx*2+1,mid+1,end); seg[idx]=seg[idx*2]+seg[idx*2+1]; } } void update(int idx,int st,int end,int pos,int val){ if(st==end){ a[pos]=val; seg[idx]=val; } else{ int mid=(st+end)/2; if(st<=pos && pos<=mid) update(idx*2,st,mid,pos,val); else update(idx*2+1,mid+1,end,pos,val); seg[idx]=seg[idx*2]+seg[idx*2+1]; } } int query(int idx,int st,int end,int qs,int qe){ if(st>end || qs>end || qe<st) return 0; if(st>=qs && end<=qe) return seg[idx]; int left = query(idx*2,st,(st+end)/2,qs,qe); int right =query(idx*2+1,1+(st+end)/2,end,qs,qe); return (left+right); } int main() { int n,type,q,x,y; scanf("%d",&n); for(int i=0;i<n;i++){ scanf("%d",&x); if(x%2) a[i]=1; } scanf("%d",&q); build(1,0,n-1); while(q--){ scanf("%d%d%d",&type,&x,&y); if(type==0){ update(1,0,n-1,x-1,y%2); } else{ int val=query(1,0,n-1,x-1,y-1); //number of odd if(type==1) val=(y-x+(y!=x?1:0))-val; printf("%d\n",val); } } }Question 3: Link to Question
Given a binary string (a string consisting only of 0s and 1s), perform two types of queries on it. Type 0: given two indices l and r, print the value of the binary substring from l to r modulo 3. Type 1: given an index l, flip the bit at that index if and only if it is 0.
Solution:
void fastpow() { p[0]=1; for(int i=1;i<=100000;++i) p[i]=(p[i-1]*2)%3; } void build(ll i, ll st, ll en) { if(st == en) { tree[i] = arr[st]; } else { ll mid = (st + en) / 2; build(2*i, st, mid); build(2*i+1, mid+1, en); tree[i] = (tree[2*i]*p[en-mid] + tree[2*i+1])%3; } } void update(ll i, ll st, ll en, ll idx) { if(st == en) { tree[i] = 1; arr[idx] = 1; } else { ll mid = (st + en) / 2; if(st <= idx and idx <= mid) { update(2*i, st, mid, idx); } else { update(2*i+1, mid+1, en, idx); } tree[i] = ((tree[2*i]*p[en-mid]%3) + tree[2*i+1])%3; } } ll query(ll i, ll st, ll en, ll l, ll r) { if(r < st or en < l) { return 0; } if(l <= st and en <= r) { return (tree[i]*p[r-en])%3; } ll mid = (st + en) / 2; ll p1 = query(2*i, st, mid, l, r); ll p2 = query(2*i+1, mid+1, en, l, r); return (p1 + p2)%3; }
Question 4: Link to Question
Solution:
#include <bits/stdc++.h>
using namespace std;

// a[] holds the 2^n leaves; seg[] is a 1-based tree whose levels alternate
// between OR and XOR, with OR on the level directly above the leaves.
int a[1 << 18], seg[1 << 20];

// Combine two child values: XOR when bit is set, OR otherwise.
static int combine(int l, int r, int bit)
{
    return bit ? (l ^ r) : (l | r);
}

// Build the tree for a[st..end] rooted at idx; `bit` selects this level's operation.
void build(int idx, int st, int end, int bit)
{
    if (st == end) {
        seg[idx] = a[st];
        return;
    }
    int mid = (st + end) / 2;
    // Children are one level lower, so they use the other operation.
    build(2 * idx, st, mid, !bit);
    build(2 * idx + 1, mid + 1, end, !bit);
    seg[idx] = combine(seg[2 * idx], seg[2 * idx + 1], bit);
}

// Set leaf pos to val and recompute the alternating operations up to the root.
void update(int idx, int st, int end, int pos, int val, int bit)
{
    if (st == end) {
        seg[idx] = val;
        return;
    }
    int mid = (st + end) / 2;
    if (pos <= mid)
        update(2 * idx, st, mid, pos, val, !bit);
    else
        update(2 * idx + 1, mid + 1, end, pos, val, !bit);
    seg[idx] = combine(seg[2 * idx], seg[2 * idx + 1], bit);
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    // Levels alternate starting with OR just above the leaves, so with n levels
    // the root uses XOR exactly when n is even.
    int bit = (n % 2 == 0) ? 1 : 0;
    n = 1 << n;
    for (int i = 0; i < n; i++)
        cin >> a[i];
    build(1, 0, n - 1, bit);
    while (m--) {
        int pos, val;
        cin >> pos >> val;
        update(1, 0, n - 1, pos - 1, val, bit);  // pos is 1-based
        cout << seg[1] << '\n';                   // no per-query flush
    }
}
Question 5: Link To Question
U i x This operation sets the value of A[i] to x.
Q x y You must find i and j such that x ≤ i, j ≤ y and i != j, such that the sum A[i]+A[j] is maximized. Print the sum A[i]+A[j].
Input:
5
1 2 3 4 5
6
Q 2 4
Q 2 5
U 1 6
Q 1 5
U 1 7
Q 1 5
Output:
7
9
11
12
Solution:
#include <bits/stdc++.h>
using namespace std;

#define N 100005

int a[N];

// Each node stores the largest (m1) and second-largest (m2) value of its segment.
// Leaves use 0 as m2, which assumes the input values are non-negative.
// A recursive segment tree uses indices up to ~4*n, hence 4*N nodes.
struct tree {
    int m1, m2;
} seg[4 * N];

// Combine two (largest, second-largest) pairs into the pair for their union.
static void merge_top2(int x1, int x2, int y1, int y2, int &max1, int &max2)
{
    int t[4] = {x1, x2, y1, y2};
    sort(t, t + 4);
    max1 = t[3];
    max2 = t[2];
}

// Build the top-two tree for a[st..end] rooted at idx.
void build(int idx, int st, int end)
{
    if (st == end) {
        seg[idx].m1 = a[st];
        seg[idx].m2 = 0;
        return;
    }
    int mid = (st + end) / 2;
    build(idx * 2, st, mid);
    build(idx * 2 + 1, mid + 1, end);
    merge_top2(seg[idx * 2].m1, seg[idx * 2].m2, seg[idx * 2 + 1].m1, seg[idx * 2 + 1].m2,
               seg[idx].m1, seg[idx].m2);
}

// Set a[pos] = val and recompute the top-two pairs on the path to the root.
void update(int idx, int st, int end, int pos, int val)
{
    if (st == end) {
        a[st] = val;
        seg[idx].m1 = val;
        seg[idx].m2 = 0;
        return;
    }
    int mid = (st + end) / 2;
    if (pos <= mid)
        update(idx * 2, st, mid, pos, val);
    else
        update(idx * 2 + 1, mid + 1, end, pos, val);
    merge_top2(seg[idx * 2].m1, seg[idx * 2].m2, seg[idx * 2 + 1].m1, seg[idx * 2 + 1].m2,
               seg[idx].m1, seg[idx].m2);
}

// Store in max1/max2 the two largest values of a[qs..qe]; [qs, qe] must lie
// inside this node's segment [st, end].
void query(int idx, int st, int end, int qs, int qe, int &max1, int &max2)
{
    // Exact match; the leaf check also stops recursion on malformed ranges.
    if ((st == qs && end == qe) || st == end) {
        max1 = seg[idx].m1;
        max2 = seg[idx].m2;
        return;
    }
    int mid = (st + end) / 2;
    if (qe <= mid) {
        query(idx * 2, st, mid, qs, qe, max1, max2);
    } else if (qs > mid) {
        query(idx * 2 + 1, mid + 1, end, qs, qe, max1, max2);
    } else {
        // Clip the range to each child so the exact-match case fires and the
        // query stays O(log n) instead of descending to every leaf in range.
        int l1, l2, r1, r2;
        query(idx * 2, st, mid, qs, mid, l1, l2);
        query(idx * 2 + 1, mid + 1, end, mid + 1, qe, r1, r2);
        merge_top2(l1, l2, r1, r2, max1, max2);
    }
}

int main()
{
    int n, q, x, y;
    char ch;
    scanf("%d", &n);
    for (int i = 0; i < n; i++)
        scanf("%d", &a[i]);
    scanf("%d", &q);
    build(1, 0, n - 1);
    while (q--) {
        cin >> ch >> x >> y;
        if (ch == 'Q') {
            int max1 = 0, max2 = 0;
            query(1, 0, n - 1, x - 1, y - 1, max1, max2);
            cout << max1 + max2 << '\n';
        } else {
            update(1, 0, n - 1, x - 1, y);
        }
    }
}
Practice Problems:
CodeForces
HackerEarth
Comments
Post a Comment