Lazy Propagation - Too lazy to update all at a time

Sometimes a problem asks you to update a whole interval [l, r] instead of a single element. One solution is to update the elements one by one, but that costs O(N log N) per operation: the interval can contain up to N elements, and updating a single element takes O(log N) time.

To avoid multiple calls to the update function, we can modify the update function to work on a whole interval at once.

For example, consider the node with value 27 in the diagram above: this node stores the sum of the values at indexes 3 to 5. If our update query is for the range 2 to 5, then we need to update this node and all of its descendants. With lazy propagation, we update only the node with value 27 and postpone the updates to its children by storing this update information in separate nodes called lazy nodes (or lazy values). We create an array lazy[] that represents the lazy nodes. The size of lazy[] is the same as that of the array representing the segment tree, which is tree[] in the code below.


The idea is to initialize all elements of lazy[] as 0. A value 0 in lazy[i] indicates that there are no pending updates on node i in segment tree. A non-zero value of lazy[i] means that this amount needs to be added to node i in segment tree before making any query to the node.

To update an interval, we keep three things in mind:
  1. If the current segment tree node has a pending update, first apply that pending update to the current node (and pass it on to its children).
  2. If the interval represented by the current node lies completely inside the interval to update, update the current node and record the update in the lazy[] entries of its children.
  3. If the interval represented by the current node only partially overlaps the interval to update, recurse into both children as in the earlier update function, then recompute the current node from them.
Operations (Templates):


UpdateRange:-
// Adds `val` to every element in [l, r] (range-sum segment tree with lazy propagation).
//   idx        : current node (root is 1, children are 2*idx and 2*idx+1)
//   start, end : the range of the array covered by node idx
//   l, r       : the range to update
// tree[idx] holds the sum of the node's range; lazy[idx] is an amount that still
// has to be added to every element under idx and is not yet included in tree[idx].
void updateRange(int idx, int start, int end, int l, int r, int val)
{
    // 1. Apply any pending update before using this node. This runs even for nodes
    //    outside [l, r] because the parent re-reads tree[idx] on the way back up.
    if (lazy[idx] != 0)
    {
        tree[idx] += (end - start + 1) * lazy[idx];   // each of the (end-start+1) elements grows by lazy[idx]
        if (start != end)
        {
            lazy[idx*2] += lazy[idx];                  // mark left child as lazy
            lazy[idx*2+1] += lazy[idx];                // mark right child as lazy
        }
        lazy[idx] = 0;                                 // reset it
    }

    if (start > end || start > r || end < l)          // outside range
        return;

    if (start >= l && end <= r)                       // 2. segment fully within range
    {
        tree[idx] += (end - start + 1) * val;
        if (start != end)                              // not a leaf: postpone the children's update
        {
            lazy[idx*2] += val;
            lazy[idx*2+1] += val;
        }
        return;
    }

    // 3. Partial overlap: update both children, then rebuild this node's sum.
    int mid = (start + end) / 2;
    updateRange(idx*2, start, mid, l, r, val);         // update left child
    updateRange(idx*2 + 1, mid + 1, end, l, r, val);   // update right child
    tree[idx] = tree[idx*2] + tree[idx*2+1];           // node stores the sum of its children
}



QueryRange:-
// Returns the sum of the elements in [l, r] (range-sum segment tree with lazy propagation).
//   idx        : current node (root is 1, children are 2*idx and 2*idx+1)
//   start, end : the range of the array covered by node idx
//   l, r       : the range to query
// Pending lazy[] amounts are applied on the way down, exactly as in updateRange.
int queryRange(int idx, int start, int end, int l, int r)
{
    if (start > end || start > r || end < l)
        return 0;                                      // out of range: contributes nothing to the sum

    if (lazy[idx] != 0)
    {
        // This node needs to be updated
        tree[idx] += (end - start + 1) * lazy[idx];    // update it
        if (start != end)
        {
            lazy[idx*2] += lazy[idx];                  // mark left child as lazy
            lazy[idx*2+1] += lazy[idx];                // mark right child as lazy
        }
        lazy[idx] = 0;                                 // reset it
    }

    if (start >= l && end <= r)                        // current segment is totally within [l, r]
        return tree[idx];

    int mid = (start + end) / 2;
    int p1 = queryRange(idx*2, start, mid, l, r);          // query left child
    int p2 = queryRange(idx*2 + 1, mid + 1, end, l, r);    // query right child
    return (p1 + p2);
}



Doubts ?

Reference Links:

Let's jump to some questions:

Question 1:

Given an array of N numbers

1: Increment the elements within range [i, j] with value val

2: Get max element within range [i, j]



Solution:
#include<bits/stdc++.h>
using namespace std;

#define N 100005

// Input array, 0-indexed.
int a[N];

// Segment-tree storage, 1-indexed (children of idx are 2*idx and 2*idx+1).
// A recursive tree over n leaves uses indices below 2 * (smallest power of two >= n).
// That approaches 4*n when n is just above a power of two, so 4*N is safe for every n <= N.
//   seg[idx]  : maximum of the node's range, not yet counting lazy[idx]
//   lazy[idx] : amount still to be added to every element under idx
int seg[4*N];
int lazy[4*N];

// Folds node idx's pending addition into seg[idx] and hands it down to the
// children (leaves have no children). Adding a constant to every element
// raises the maximum by that same constant, so no length factor is needed.
static void pushLazy(int idx, int st, int end) {
    if (lazy[idx] == 0)
        return;
    seg[idx] += lazy[idx];
    if (st != end) {
        lazy[idx*2] += lazy[idx];     // mark left child as lazy
        lazy[idx*2+1] += lazy[idx];   // mark right child as lazy
    }
    lazy[idx] = 0;
}

// Builds the subtree rooted at idx over a[st..end]. It also clears stale lazy
// values, so the tree can be rebuilt for another test case.
void build(int idx, int st, int end) {
    if (st > end)          // empty range (n == 0); without this the recursion never ends
        return;
    lazy[idx] = 0;
    if (st == end) {
        seg[idx] = a[st];  // leaf: init value
        return;
    }
    int mid = (st + end) / 2;
    build(idx*2, st, mid);
    build(idx*2+1, mid+1, end);
    seg[idx] = std::max(seg[idx*2], seg[idx*2+1]);
}

// Adds val to every element of [qs, qe].
void update(int idx, int st, int end, int qs, int qe, int val) {
    // Apply pending work first, even for nodes outside [qs, qe]: the parent
    // recomputes its maximum from this node's seg[] on the way back up.
    pushLazy(idx, st, end);

    if (st > end || st > qe || end < qs)
        return;

    if (st >= qs && end <= qe) {   // segment is fully within range
        lazy[idx] = val;           // lazy[idx] is 0 here, so this just schedules val
        pushLazy(idx, st, end);    // apply to this node, postpone for the children
        return;
    }

    int mid = (st + end) / 2;
    update(idx*2, st, mid, qs, qe, val);
    update(idx*2+1, mid+1, end, qs, qe, val);
    seg[idx] = std::max(seg[idx*2], seg[idx*2+1]);
}

// Returns the maximum of the elements in [qs, qe].
int query(int idx, int st, int end, int qs, int qe) {
    // Out of range: return the identity element of max, so it can never win.
    // (Returning 0 here gave wrong answers when all values in [qs, qe] were negative.)
    if (st > end || st > qe || end < qs)
        return INT_MIN;

    pushLazy(idx, st, end);

    if (st >= qs && end <= qe)
        return seg[idx];

    int mid = (st + end) / 2;
    int q1 = query(idx*2, st, mid, qs, qe);
    int q2 = query(idx*2+1, mid+1, end, qs, qe);
    return std::max(q1, q2);
}

Question 2 (Link to Question)

There are N coins, numbered 0 to N-1, all initially tails up. Process two kinds of queries:

0. Flip all coins numbered between A and B inclusive
1. How many coins numbered between A and B inclusive are heads up

Input                                                             Output

4 7                                  0
1 0 3                                1
0 1 2                                0
1 0 1                                2
1 0 0                                1
0 0 3
1 0 3 
1 3 3
Solution:
#include <bits/stdc++.h>
using namespace std;

// Segment tree over the coins, 1-indexed (children of idx are 2*idx and 2*idx+1).
//   seg[idx]  : number of heads-up coins in the node's range, not yet counting lazy[idx]
//   lazy[idx] : 1 if the node has a pending flip not yet applied to seg[idx], else 0.
//               Two flips cancel out, which is why pending flips are toggled (1 - x).
int seg[272143],lazy[272143];

// Applies node idx's pending flip (if any): heads and tails swap, so the
// heads count becomes (range length - heads). The flip is handed to the children.
static void pushFlip(int idx, int st, int end) {
    if (lazy[idx] == 0)
        return;
    seg[idx] = (end - st + 1) - seg[idx];
    if (st != end) {
        lazy[idx*2] = 1 - lazy[idx*2];
        lazy[idx*2+1] = 1 - lazy[idx*2+1];
    }
    lazy[idx] = 0;
}

// Builds the subtree rooted at idx over coins [st, end], all tails up.
// It also clears stale pending flips, so the tree can be rebuilt.
void build(int idx,int st,int end){
    if (st > end)
        return;
    lazy[idx] = 0;
    if (st == end) {
        seg[idx] = 0;
        return;
    }
    int mid = (st + end) / 2;
    build(idx*2, st, mid);
    build(idx*2+1, mid+1, end);
    seg[idx] = seg[idx*2] + seg[idx*2+1];
}

// Flips every coin in [qs, qe].
void updateRange(int idx,int st,int end,int qs,int qe){
    if (st > end || qs > qe)
        return;
    // Apply pending flips even outside [qs, qe]: the parent re-reads seg[idx].
    pushFlip(idx, st, end);
    if (qs > end || qe < st)
        return;
    if (st >= qs && end <= qe) {
        // Fully covered: schedule one flip and apply it right away to this node only.
        lazy[idx] = 1;
        pushFlip(idx, st, end);
        return;     // important: the children are handled lazily
    }
    int mid = (st + end) / 2;
    updateRange(idx*2, st, mid, qs, qe);      // the original recursed into a nonexistent update()
    updateRange(idx*2+1, mid+1, end, qs, qe);
    seg[idx] = seg[idx*2] + seg[idx*2+1];
}

// Returns how many coins in [qs, qe] are heads up.
int queryRange(int idx,int st,int end,int qs,int qe){
    if (st > end || qs > end || qe < st)
        return 0;
    pushFlip(idx, st, end);
    if (st >= qs && end <= qe)
        return seg[idx];
    int mid = (st + end) / 2;
    int left = queryRange(idx*2, st, mid, qs, qe);        // the original recursed into a nonexistent query()
    int right = queryRange(idx*2+1, mid+1, end, qs, qe);
    return left + right;
}

// Reads "n q", then q lines "type A B":
//   type 0 -> flip coins A..B, anything else -> print the heads-up count in A..B.
int main() {
    int coins, queries;
    scanf("%d%d", &coins, &queries);
    build(1, 0, coins - 1);
    const int last = coins - 1;
    while (queries--) {
        int op, lo, hi;
        scanf("%d%d%d", &op, &lo, &hi);
        if (op != 0) {
            printf("%d\n", queryRange(1, 0, last, lo, hi));
            continue;
        }
        updateRange(1, 0, last, lo, hi);
    }
    return 0;
}
Question 3 (Link to Question)
2 st nd -- return the sum of the squares of the numbers with indices in [st, nd]
1 st nd x -- add "x" to all numbers with indices in [st, nd] 
0 st nd x -- set all numbers with indices in [st, nd] to "x"
Input                             Output
2                                 Case 1:
4 5                               30
1 2 3 4                           7   
2 1 4                             13
0 3 4 1                           Case 2:
2 1 4                             1
1 3 4 1
2 1 4
1 1
1
2 1 1
Solution:
#include <bits/stdc++.h>
using namespace std;

#define optimizar_io ios_base::sync_with_stdio(0);cin.tie(0);

#define N 100005
#define ll long long int

// One segment-tree node, 1-indexed (children of idx are 2*idx and 2*idx+1).
//   sum    : sum of the values in the node's range
//   sqrsum : sum of the squares of the values in the node's range
//   type   : update pending on this node, not yet folded into sum/sqrsum:
//            0 = none, 1 = add `lazy` to every element, 2 = set every element to `lazy`
//   lazy   : operand of the pending update
struct node 
{
    ll sum,sqrsum,lazy,type;
}seg[3*N];

int a[N];

// Builds the subtree rooted at idx over a[st..end] and clears pending updates
// (the tree is reused across test cases).
void build(int idx,int st,int end){
    if(st>end)
        return;
    seg[idx].type=0;
    seg[idx].lazy=0;
    if(st==end){
        seg[idx].sum=a[st];
        seg[idx].sqrsum=static_cast<ll>(a[st])*a[st];   // widen first: the square can overflow int
        return;
    }
    int mid=(st+end)/2;
    build(idx*2,st,mid);
    build(idx*2+1,mid+1,end);
    seg[idx].sum=seg[idx*2].sum+seg[idx*2+1].sum;
    seg[idx].sqrsum=seg[idx*2].sqrsum+seg[idx*2+1].sqrsum;
}

// Updates node idx's sums for an update of kind `type` with operand `val`
// covering its whole range [st, end]. All arithmetic is done in ll.
static void applyToNode(int idx, int st, int end, ll type, ll val) {
    ll len = end - st + 1;
    if (type == 1) {
        // sum((x+v)^2) = sum(x^2) + 2*v*sum(x) + len*v^2; use the old sum, so update sqrsum first.
        seg[idx].sqrsum += 2*seg[idx].sum*val + len*val*val;
        seg[idx].sum += len*val;
    } else if (type == 2) {
        seg[idx].sqrsum = len*val*val;
        seg[idx].sum = len*val;
    }
}

// Records an update (kind `type`, operand `val`) as pending on node idx, composing
// it with whatever is already pending there instead of overwriting it.
// The already-pending update is always the older one, so the new update goes on top.
static void addPending(int idx, ll type, ll val) {
    if (type == 2 || seg[idx].type == 0) {
        // A set overrides anything pending; on a clean node any update is stored as-is.
        seg[idx].type = type;
        seg[idx].lazy = val;
    } else if (type == 1) {
        // add after add -> larger add; add after set -> set to (old value + val).
        seg[idx].lazy += val;
    }
}

// Applies node idx's pending update to its own sums and passes it to the children.
static void pushPending(int idx, int st, int end) {
    if (seg[idx].type == 0)
        return;
    applyToNode(idx, st, end, seg[idx].type, seg[idx].lazy);
    if (st != end) {
        addPending(2*idx, seg[idx].type, seg[idx].lazy);
        addPending(2*idx+1, seg[idx].type, seg[idx].lazy);
    }
    seg[idx].lazy = 0;
    seg[idx].type = 0;
}

// Applies an update to [qs, qe]: type 1 adds val, type 2 sets every element to val.
void updateRange(int idx,int st,int end,int qs,int qe,int val,int type){
    // Apply pending work first, even outside [qs, qe]: the parent re-reads this node's sums.
    pushPending(idx, st, end);
    if(st>end || st>qe || end<qs)
        return;
    if(st>=qs && end<=qe){
        applyToNode(idx, st, end, type, val);
        if(st!=end){
            addPending(2*idx, type, val);
            addPending(2*idx+1, type, val);
        }
        return;
    }
    int mid=(st+end)/2;
    updateRange(2*idx,st,mid,qs,qe,val,type);
    updateRange(2*idx+1,mid+1,end,qs,qe,val,type);
    seg[idx].sqrsum=seg[2*idx].sqrsum+seg[2*idx+1].sqrsum;
    seg[idx].sum=seg[2*idx].sum+seg[2*idx+1].sum;
}

// Returns the sum of the squares of the elements in [qs, qe].
// The return type is ll: the result can exceed the range of int.
ll queryRange(int idx,int st,int end,int qs,int qe){
    if(st>end || qs>end || qe<st)
        return 0;
    pushPending(idx, st, end);
    if(st>=qs && end<=qe)
        return seg[idx].sqrsum;
    int mid=(st+end)/2;
    ll left = queryRange(idx*2,st,mid,qs,qe);
    ll right = queryRange(idx*2+1,mid+1,end,qs,qe);
    return left+right;
}

// Reads t test cases. Each case gives n, q and the array, then q queries with 1-based indices:
//   "0 st nd x" -> set [st, nd] to x      (tree update type 2)
//   "1 st nd x" -> add x to [st, nd]      (tree update type 1)
//   "2 st nd"   -> print the sum of squares over [st, nd]
int main() {
    int t,n,q,qs,qe,val,type;
    optimizar_io
    cin>>t;
    for(int tc=1;tc<=t;tc++){
        // '\n' instead of endl: flushing after every line is needless and slow.
        cout<<"Case "<<tc<<":"<<'\n';
        cin>>n>>q;
        for(int i=0;i<n;i++)
            cin>>a[i];
        build(1,0,n-1);
        while(q--){
            cin>>type>>qs>>qe;
            if(type==0){
                cin>>val;
                updateRange(1,0,n-1,qs-1,qe-1,val,2);
            }
            else if(type==1){
                cin>>val;
                updateRange(1,0,n-1,qs-1,qe-1,val,1);
            }
            else{
                cout<<queryRange(1,0,n-1,qs-1,qe-1)<<'\n';
            }
        }
    }
    return 0;
}

Practice Problems:
  1. spoj.com HORRIBLE
  2. spoj.com LITE
  3. spoj.com MULTQ3
  4. codechef FUNAGP (difficult)
CodeForces
Problem SPOJ




















Comments

Popular posts from this blog

Combinatorial Game Theory

Bit Manipulation Problems