Lazy Propagation - Too lazy to update all at a time
Sometimes a problem asks you to update a whole interval from l to r instead of a single element. One solution is to update the elements of the interval one by one. The complexity of this approach is O(N log N) per operation in the worst case, since there can be up to N elements in the interval and updating a single element takes O(log N) time.
To avoid multiple calls to the update function, we can modify the update function to work on an interval.
For example consider the node with value 27 in above diagram, this node stores sum of values at indexes from 3 to 5. If our update query is for range 2 to 5, then we need to update this node and all descendants of this node. With Lazy propagation, we update only node with value 27 and postpone updates to its children by storing this update information in separate nodes called lazy nodes or values. We create an array lazy[] which represents lazy node. Size of lazy[] is same as array that represents segment tree, which is tree[] in below code.
The idea is to initialize all elements of lazy[] as 0. A value 0 in lazy[i] indicates that there are no pending updates on node i in segment tree. A non-zero value of lazy[i] means that this amount needs to be added to node i in segment tree before making any query to the node.
To update an interval we will keep 3 things in mind.
- If current segment tree node has any pending update, then first add that pending update to current node.
- If the interval represented by current node lies completely in the interval to update, then update the current node and update the lazy[] array for children nodes.
- If the interval represented by the current node only partially overlaps the interval to update, then recurse into both children and recompute the current node from them, as in the earlier (single-element) update function.
Operations(Templates):-
UpdateRange:-
// Adds `val` to every array element whose index lies in [l, r].
//
// idx        - index of the current node in tree[] and lazy[]
//              (its children are idx*2 and idx*2+1)
// start, end - inclusive range of array indexes covered by node idx
// l, r       - inclusive range of array indexes to update
// val        - amount added to each element in [l, r]
//
// tree[idx] holds the sum of the node's range; lazy[idx] holds an amount that
// still has to be added to every element of the node's range (0 = nothing pending).
void updateRange(int idx, int start, int end, int l, int r, int val)
{
    if(lazy[idx] != 0) // 1
    {
        // This node needs to be updated
        tree[idx] += (end - start + 1) * lazy[idx]; // Update it
        if(start != end)
        {
            lazy[idx*2] += lazy[idx];   // Mark child as lazy
            lazy[idx*2+1] += lazy[idx]; // Mark child as lazy
        }
        lazy[idx] = 0; // Reset it
    }
    if(start > end or start > r or end < l) // Outside Range
        return;
    if(start >= l and end <= r) // 2
    {
        // Segment is fully within range
        tree[idx] += (end - start + 1) * val;
        if(start != end)
        {
            // Not leaf node: postpone the update of the children
            lazy[idx*2] += val;
            lazy[idx*2+1] += val;
        }
        return;
    }
    int mid = (start + end) / 2; // 3
    updateRange(idx*2, start, mid, l, r, val);       // Updating left child
    updateRange(idx*2 + 1, mid + 1, end, l, r, val); // Updating right child
    tree[idx] = tree[idx*2] + tree[idx*2+1];         // Updating root with the sum of its children
}
QueryRange:-
// Returns the sum of the array elements whose index lies in [l, r].
//
// idx        - index of the current node in tree[] and lazy[]
//              (its children are idx*2 and idx*2+1)
// start, end - inclusive range of array indexes covered by node idx
// l, r       - inclusive range of array indexes to sum
//
// A node that does not intersect [l, r] contributes 0 to the sum.
int queryRange(int idx, int start, int end, int l, int r)
{
    if(start > end or start > r or end < l)
        return 0; // Out of range
    if(lazy[idx] != 0)
    {
        // This node needs to be updated
        tree[idx] += (end - start + 1) * lazy[idx]; // Update it
        if(start != end)
        {
            lazy[idx*2] += lazy[idx];   // Mark child as lazy
            lazy[idx*2+1] += lazy[idx]; // Mark child as lazy
        }
        lazy[idx] = 0; // Reset it
    }
    if(start >= l and end <= r) // Current segment is totally within range [l, r]
        return tree[idx];
    int mid = (start + end) / 2;
    int p1 = queryRange(idx*2, start, mid, l, r);         // Query left child
    int p2 = queryRange(idx*2 + 1, mid + 1, end, l, r);   // Query right child
    return (p1 + p2);
}
Doubts ?
Reference Links:
Let's jump to some questions:
Question 1:
Given an array of N numbers
1: Increment the elements within range [i, j] with value val
2: Get max element within range [i, j]
Question 1:
Given an array of N numbers
1: Increment the elements within range [i, j] with value val
2: Get max element within range [i, j]
Solution:
#include <algorithm>
#include <climits>
using namespace std;

const int N = 100005;

int a[N];        // input array, indexes st..end as passed to build()
int seg[4*N];    // seg[idx] = maximum of the range covered by node idx
int lazy[4*N];   // lazy[idx] = amount still to be added to every element of node idx (0 = none)
// 4*N nodes are always enough for a tree rooted at index 1 whose children are 2*idx and 2*idx+1.

// Applies the pending addition of node idx (covering [st, end]) to the node
// itself and hands it down to the children, then clears it.
static void pushDown(int idx, int st, int end)
{
    if (lazy[idx] == 0)
        return;
    seg[idx] += lazy[idx]; // adding x to every element raises the maximum by x
    if (st != end)
    {
        lazy[idx*2] += lazy[idx];   // Mark child as lazy
        lazy[idx*2+1] += lazy[idx]; // Mark child as lazy
    }
    lazy[idx] = 0; // Reset it
}

// Builds the max-tree for a[st..end] into the subtree rooted at idx.
void build(int idx, int st, int end)
{
    if (st == end)
    {
        seg[idx] = a[st]; // Init value
        return;
    }
    int mid = (st + end) / 2;
    build(idx*2, st, mid);
    build(idx*2+1, mid+1, end);
    seg[idx] = max(seg[idx*2], seg[idx*2+1]);
}

// Adds val to every element with index in [qs, qe].
// idx is the current node, [st, end] the range it covers.
void update(int idx, int st, int end, int qs, int qe, int val)
{
    pushDown(idx, st, end);
    if (st > end || st > qe || end < qs)
        return;
    if (st >= qs && end <= qe)
    {
        // Segment is fully within range
        seg[idx] += val;
        if (st != end)
        {
            lazy[idx*2] += val;
            lazy[idx*2+1] += val;
        }
        return;
    }
    int mid = (st + end) / 2;
    update(idx*2, st, mid, qs, qe, val);
    update(idx*2+1, mid+1, end, qs, qe, val);
    seg[idx] = max(seg[idx*2], seg[idx*2+1]);
}

// Returns the maximum element with index in [qs, qe].
// idx is the current node, [st, end] the range it covers.
// Returns INT_MIN when [st, end] and [qs, qe] do not intersect.
int query(int idx, int st, int end, int qs, int qe)
{
    // A node outside the query must not influence the maximum. Returning 0
    // here would turn the answer for an all-negative range into 0.
    if (st > end || st > qe || end < qs)
        return INT_MIN;
    pushDown(idx, st, end);
    if (st >= qs && end <= qe)
        return seg[idx];
    int mid = (st + end) / 2;
    int q1 = query(idx*2, st, mid, qs, qe);
    int q2 = query(idx*2+1, 1+mid, end, qs, qe);
    return max(q1, q2);
}
Question 2:
There are N coins numbered 0 to N-1, all tails up initially. Two operations are supported:
0. Flip all coins numbered between A and B inclusive
1. How many coins numbered between A and B inclusive are heads up
Input:
4 7
1 0 3
0 1 2
1 0 1
1 0 0
0 0 3
1 0 3
1 3 3
Output:
0
1
0
2
1
Solution:
#include <cstdio>
using namespace std;

// seg[idx]  = number of heads-up coins in the range covered by node idx
// lazy[idx] = 1 when a flip of node idx's whole range is still pending, 0 otherwise
int seg[272143],lazy[272143];

// Flips every coin in the range [st, end] of node idx: the head count of the
// node is updated right away, the children only get their pending mark toggled.
static void flipNode(int idx,int st,int end){
    seg[idx]=(end-st+1)-seg[idx];
    if(st!=end){
        lazy[idx*2]=1-lazy[idx*2];
        lazy[idx*2+1]=1-lazy[idx*2+1];
    }
}

// Initialises the subtree rooted at idx for coins st..end, all tails up.
void build(int idx,int st,int end){
    if(st>end) return;
    lazy[idx]=0;
    if(st==end){
        seg[idx]=0;
        return;
    }
    build(idx*2,st,(st+end)/2);
    build(idx*2+1,(st+end)/2+1,end);
    seg[idx]=seg[idx*2]+seg[idx*2+1];
}

// Flips all coins with index in [qs, qe].
// idx is the current node, [st, end] the range it covers.
void updateRange(int idx,int st,int end,int qs,int qe){
    if(st>end || qs>qe) return;
    if(lazy[idx]!=0){
        // Carry out the flip that was postponed for this node
        flipNode(idx,st,end);
        lazy[idx]=0;
    }
    if(qs>end || qe<st) return;
    if(st>=qs && end<=qe){
        flipNode(idx,st,end);
        return; // Important: do not descend, the children are handled lazily
    }
    updateRange(idx*2,st,(st+end)/2,qs,qe);
    updateRange(idx*2+1,1+(st+end)/2,end,qs,qe);
    seg[idx]=seg[idx*2]+seg[idx*2+1];
}

// Returns how many coins with index in [qs, qe] are heads up.
// idx is the current node, [st, end] the range it covers.
int queryRange(int idx,int st,int end,int qs,int qe){
    if(st>end || qs>end || qe<st) return 0;
    if(lazy[idx]!=0){
        flipNode(idx,st,end);
        lazy[idx]=0;
    }
    if(st>=qs && end<=qe) return seg[idx];
    int left = queryRange(idx*2,st,(st+end)/2,qs,qe);
    int right = queryRange(idx*2+1,1+(st+end)/2,end,qs,qe);
    return (left+right);
}

// Reads n and q, then q operations "type x y":
// type 0 flips coins x..y, any other type prints the number of heads in x..y.
int main()
{
    int n,type,q,x,y;
    scanf("%d%d",&n,&q);
    build(1,0,n-1);
    while(q--){
        scanf("%d%d%d",&type,&x,&y);
        if(type==0){
            updateRange(1,0,n-1,x,y);
        }
        else{
            printf("%d\n",queryRange(1,0,n-1,x,y));
        }
    }
}
Question 3: Link to Question
2 st nd -- return the sum of the squares of the numbers with indices in [st, nd]
1 st nd x -- add "x" to all numbers with indices in [st, nd]
0 st nd x -- set all numbers with indices in [st, nd] to "x"
Input:
2
4 5
1 2 3 4
2 1 4
0 3 4 1
2 1 4
1 3 4 1
2 1 4
1 1
1
2 1 1
Output:
Case 1:
30
7
13
Case 2:
1
Solution:#include <bits/stdc++.h> using namespace std; #define optimizar_io ios_base::sync_with_stdio(0);cin.tie(0); #define N 100005 #define ll long long int struct node { ll sum,sqrsum,lazy,type; }seg[3*N]; int a[N]; void build(int idx,int st,int end){ if(st>end) return; if(st==end){ seg[idx].type=0; seg[idx].lazy=0; seg[idx].sum=a[st]; seg[idx].sqrsum=a[st]*a[st]; return; } seg[idx].type=0; seg[idx].lazy=0; build(idx*2,st,(st+end)/2); build(idx*2+1,(st+end)/2+1,end); seg[idx].sum=seg[idx*2].sum+seg[idx*2+1].sum; seg[idx].sqrsum=seg[idx*2].sqrsum+seg[idx*2+1].sqrsum; } void updateRange(int idx,int st,int end,int qs,int qe,int val,int type){ if(seg[idx].type!=0){ if(seg[idx].type==1){ seg[idx].sqrsum+=2*seg[idx].sum*seg[idx].lazy+(end-st+1)*seg[idx].lazy*seg[idx].lazy; seg[idx].sum+=(end-st+1)*seg[idx].lazy; } else if(seg[idx].type==2){ seg[idx].sqrsum=(end-st+1)*seg[idx].lazy*seg[idx].lazy; seg[idx].sum=(end-st+1)*seg[idx].lazy; } if(st!=end){ seg[2*idx].type=seg[idx].type; seg[2*idx].lazy=seg[idx].lazy; seg[2*idx+1].type=seg[idx].type; seg[2*idx+1].lazy=seg[idx].lazy; } seg[idx].lazy=0; seg[idx].type=0; } if(st>end || st>qe || end<qs) return; if(st>=qs && end<=qe){ if(type==1){ seg[idx].sqrsum+=2*seg[idx].sum*val+(end-st+1)*val*val; seg[idx].sum+=(end-st+1)*val; } else if(type==2){ seg[idx].sqrsum=(end-st+1)*val*val; seg[idx].sum=(end-st+1)*val; } if(st!=end){ seg[2*idx].type=type; seg[2*idx].lazy=val; seg[2*idx+1].type=type; seg[2*idx+1].lazy=val; } return; } int mid=(st+end)/2; updateRange(2*idx,st,mid,qs,qe,val,type); updateRange(2*idx+1,mid+1,end,qs,qe,val,type); seg[idx].sqrsum=seg[2*idx].sqrsum+seg[2*idx+1].sqrsum; seg[idx].sum=seg[2*idx].sum+seg[2*idx+1].sum; } int queryRange(int idx,int st,int end,int qs,int qe){ if(st>end || qs>end || qe<st) return 0; if(seg[idx].type!=0){ if(seg[idx].type==1){ seg[idx].sqrsum+=2*seg[idx].sum*seg[idx].lazy+(end-st+1)*seg[idx].lazy*seg[idx].lazy; seg[idx].sum+=(end-st+1)*seg[idx].lazy; } else if(seg[idx].type==2){ 
seg[idx].sqrsum=(end-st+1)*seg[idx].lazy*seg[idx].lazy; seg[idx].sum=(end-st+1)*seg[idx].lazy; } if(st!=end){ seg[2*idx].type=seg[idx].type; seg[2*idx].lazy=seg[idx].lazy; seg[2*idx+1].type=seg[idx].type; seg[2*idx+1].lazy=seg[idx].lazy; } seg[idx].lazy=0; seg[idx].type=0; } if(st>=qs && end<=qe) return seg[idx].sqrsum; int left = queryRange(idx*2,st,(st+end)/2,qs,qe); int right =queryRange(idx*2+1,1+(st+end)/2,end,qs,qe); return (left+right); } int main() { int t,n,q,qs,qe,val,x,type; optimizar_io cin>>t; for(int i=1;i<=t;i++){ cout<<"Case "<<i<<":"<<endl; cin>>n>>q; for(int i=0;i<n;i++) cin>>a[i]; build(1,0,n-1); while(q--){ cin>>type>>qs>>qe; if(type==0){ cin>>val; updateRange(1,0,n-1,qs-1,qe-1,val,2); } else if(type==1){ cin>>val; updateRange(1,0,n-1,qs-1,qe-1,val,1); } else{ cout<<queryRange(1,0,n-1,qs-1,qe-1)<<endl; } } } return 0; }
Practice Problems:
CodeForcesProblem SPOJ
Comments
Post a Comment