标签:
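2D Segment Tree implemented as a Quad Tree. A point update takes O(log_4 N) time for N cells; a range-sum query is logarithmic for well-aligned rectangles but can degrade toward O(sqrt(N)) for long, thin ones.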
// Quad-tree node storing the sum of a rectangular region of a matrix.
//
// A node covers the inclusive rectangle [x0, x1] x [y0, y1], where x indexes
// columns and y indexes rows (cells are read as m[y][x]). A non-leaf node
// splits its rectangle at the midpoints into up to four children:
//   ul = upper-left, ur = upper-right, dl = lower-left, dr = lower-right.
// A child that would cover an empty rectangle is not allocated and stays null.
//
// The node owns its children and frees them in its destructor; copying is
// disabled because a member-wise copy would make two nodes own the same
// subtree.
class Node // 2D Segment Tree
{
public:
    // Builds the subtree for the inclusive rectangle (ix0, iy0)-(ix1, iy1) of m.
    // An inverted rectangle (ix0 > ix1 or iy0 > iy1) yields an empty node with
    // sum 0 and no children.
    // NOTE(review): assumes every row of m has at least ix1 + 1 entries and m
    // has at least iy1 + 1 rows; nothing here validates that.
    Node(const std::vector<std::vector<int>> &m, int ix0, int iy0, int ix1, int iy1)
        : sum(0), x0(ix0), x1(ix1), y0(iy0), y1(iy1),
          ul(nullptr), ur(nullptr), dl(nullptr), dr(nullptr)
    {
        if (ix0 > ix1 || iy0 > iy1)
            return;
        if (ix0 == ix1 && iy0 == iy1)
        {
            // Single cell: the leaf holds the cell value itself.
            sum = m[iy0][ix0];
            return;
        }

        const int xmid = getMidX();
        const int ymid = getMidY();
        const bool hasRight = ix1 > xmid;
        const bool hasLower = iy1 > ymid;

        // The upper-left quadrant always exists for a non-empty rectangle.
        ul = new Node(m, ix0, iy0, xmid, ymid);
        sum += ul->sum;
        if (hasRight)
        {
            ur = new Node(m, xmid + 1, iy0, ix1, ymid);
            sum += ur->sum;
        }
        if (hasLower)
        {
            dl = new Node(m, ix0, ymid + 1, xmid, iy1);
            sum += dl->sum;
        }
        if (hasRight && hasLower)
        {
            dr = new Node(m, xmid + 1, ymid + 1, ix1, iy1);
            sum += dr->sum;
        }
    }

    // Releases the whole subtree (deleting a null child is a no-op).
    ~Node()
    {
        delete ul;
        delete ur;
        delete dl;
        delete dr;
    }

    // Children are owned through raw pointers, so a shallow copy would lead
    // to a double delete.
    Node(const Node &) = delete;
    Node &operator=(const Node &) = delete;

    // Sets cell (rx, ry) to val and returns the change (val - old value) so
    // that ancestors can adjust their own sums. A coordinate outside this
    // node's rectangle (which includes every coordinate for an empty node)
    // changes nothing and returns 0.
    long long update(int rx, int ry, long long val)
    {
        if (rx < x0 || rx > x1 || ry < y0 || ry > y1)
            return 0;
        if (x0 == x1 && y0 == y1)
        {
            const long long delta = val - sum;
            sum = val;
            return delta;
        }

        // The quadrant containing an in-range coordinate always has a child:
        // e.g. rx > xmid together with rx <= x1 implies x1 > xmid, which is
        // exactly the condition under which the right-hand children are built.
        const bool left = rx <= getMidX();
        const bool upper = ry <= getMidY();
        Node *child = upper ? (left ? ul : ur) : (left ? dl : dr);

        const long long delta = child->update(rx, ry, val);
        sum += delta;
        return delta;
    }

    // Returns the sum over the inclusive rectangle (rx0, ry0)-(rx1, ry1).
    // The rectangle is first clipped to this node's bounds, so parts lying
    // outside contribute 0 and a rectangle with no overlap returns 0.
    long long get(int rx0, int ry0, int rx1, int ry1) const
    {
        rx0 = std::max(rx0, x0);
        ry0 = std::max(ry0, y0);
        rx1 = std::min(rx1, x1);
        ry1 = std::min(ry1, y1);
        if (rx0 > rx1 || ry0 > ry1)
            return 0;
        if (rx0 == x0 && rx1 == x1 && ry0 == y0 && ry1 == y1)
            return sum;

        // A partially covered node cannot be a leaf, so ul exists; each other
        // child is only visited when the clipped query reaches past the
        // midpoint, which implies that child was built. Children clip the
        // query to their own bounds themselves.
        const int xmid = getMidX();
        const int ymid = getMidY();
        long long total = 0;
        if (rx0 <= xmid && ry0 <= ymid)
            total += ul->get(rx0, ry0, rx1, ry1);
        if (rx1 > xmid && ry0 <= ymid)
            total += ur->get(rx0, ry0, rx1, ry1);
        if (rx0 <= xmid && ry1 > ymid)
            total += dl->get(rx0, ry0, rx1, ry1);
        if (rx1 > xmid && ry1 > ymid)
            total += dr->get(rx0, ry0, rx1, ry1);
        return total;
    }

private:
    // Midpoints written as lo + (hi - lo) / 2 rather than (lo + hi) / 2.
    int getMidX() const { return x0 + (x1 - x0) / 2; }
    int getMidY() const { return y0 + (y1 - y0) / 2; }

private:
    // mem vars
    long long sum; // sum of all cells covered by this node
    int x0, x1;    // inclusive column range
    int y0, y1;    // inclusive row range
    Node *ul;      // upper-left child  (owned, may be null)
    Node *ur;      // upper-right child (owned, may be null)
    Node *dl;      // lower-left child  (owned, may be null)
    Node *dr;      // lower-right child (owned, may be null)
};
class NumMatrix
{
Node *pSeg;
public:
NumMatrix(vector<vector<int>> &matrix)
{
int h = matrix.size();
if(!h) return;
int w = matrix[0].size();
pSeg = new Node(matrix, 0, 0, w - 1, h - 1);
}
void update(int row, int col, int val)
{
if(pSeg)
pSeg->update(col, row, val);
}
int sumRegion(int row1, int col1, int row2, int col2)
{
if(pSeg)
return pSeg->get(col1, row1, col2, row2);
return 0;
}
};
LeetCode "Range Sum Query 2D - Mutable"
标签:
原文地址:http://www.cnblogs.com/tonix/p/4987341.html