1 minute read

segment tree 연습용으로 찾은 문제 307. Range Sum Query - Mutable.

/*
 * One node of the segment tree.
 *
 * A node covers the inclusive index range [start, end] of the source array
 * and caches the sum of the elements in that range. A leaf (start == end)
 * has no children; an internal node splits its range at the midpoint between
 * its left and right child.
 */
typedef struct st {
    int sum;            /* sum of the elements in [start, end] */
    int start;          /* first index covered (inclusive) */
    int end;            /* last index covered (inclusive) */
    struct st *left;    /* child covering the lower half, NULL for a leaf */
    struct st *right;   /* child covering the upper half, NULL for a leaf */
} NODE;

/* Handle for the range-sum structure: owns the root of the segment tree. */
typedef struct {
    NODE *root; /* root of the tree; NULL when built from an empty array */
} NumArray;

/* Releases a partially built subtree; used only when an allocation fails. */
static void discardSubtree(NODE *node) {
    if (node == NULL) {
        return;
    }
    discardSubtree(node->left);
    discardSubtree(node->right);
    free(node);
}

/**
 * Recursively builds the segment-tree node covering nums[start..end].
 *
 * @param nums  source array; indices start..end must be readable
 * @param start first index of the range (inclusive)
 * @param end   last index of the range (inclusive)
 * @return the new subtree root, or NULL when the range is empty
 *         (start > end) or when memory could not be allocated. On
 *         allocation failure nothing built so far is leaked.
 */
NODE* addNode(int* nums, int start, int end) {
    if (start > end) {
        return NULL;
    }

    NODE *node = calloc(1, sizeof *node);
    if (node == NULL) {
        return NULL;
    }
    node->start = start;
    node->end = end;

    /* Leaf: the cached sum is the element itself. */
    if (start == end) {
        node->sum = nums[start];
        return node;
    }

    /* start < end here, so both halves are non-empty. */
    int mid = (start + end) / 2;
    node->left = addNode(nums, start, mid);
    node->right = addNode(nums, mid + 1, end);

    /* A missing child can only mean an allocation failed further down;
     * give up on this subtree rather than return a half-built node. */
    if (node->left == NULL || node->right == NULL) {
        discardSubtree(node);
        return NULL;
    }

    node->sum = node->left->sum + node->right->sum;
    return node;
}

/**
 * Creates a range-sum structure over nums[0..numsSize-1].
 *
 * @param nums     source array (read only while the tree is being built)
 * @param numsSize number of elements; 0 produces a handle with a NULL root
 * @return a heap-allocated handle the caller releases with numArrayFree(),
 *         or NULL if memory could not be allocated.
 */
NumArray* numArrayCreate(int* nums, int numsSize) {
    NumArray *ret = calloc(1, sizeof *ret);
    if (ret == NULL) {
        return NULL;
    }

    ret->root = addNode(nums, 0, numsSize - 1);

    /* A NULL root for a non-empty input means the tree could not be built;
     * report that instead of handing back a handle that looks empty. */
    if (numsSize > 0 && ret->root == NULL) {
        free(ret);
        return NULL;
    }
    return ret;
}

/**
 * Sets the element at `index` to `val` and refreshes the cached sums on the
 * path from the affected leaf back up to `node`.
 *
 * @param node  root of the subtree to update; NULL is ignored
 * @param index array index to overwrite
 * @param val   new value for that element
 *
 * An index outside [node->start, node->end] is ignored, so a bad index can
 * no longer overwrite the first or last leaf by accident.
 */
void updateNode(NODE *node, int index, int val) {
    if (node == NULL) {
        return;
    }
    if (index < node->start || index > node->end) {
        return;
    }

    /* Leaf: the range check above guarantees this is the leaf for `index`. */
    if (node->start == node->end) {
        node->sum = val;
        return;
    }

    /* Descend into the half that contains `index`. */
    int mid = (node->start + node->end) / 2;
    updateNode(index <= mid ? node->left : node->right, index, val);

    /* Recompute this node's sum from its children. */
    int total = 0;
    if (node->left != NULL) {
        total += node->left->sum;
    }
    if (node->right != NULL) {
        total += node->right->sum;
    }
    node->sum = total;
}

/* Sets element `index` of the array to `val` by updating the tree from its root. */
void numArrayUpdate(NumArray* obj, int index, int val) {
    updateNode(obj->root, index, val);
}

/**
 * Returns the sum of the elements with indices in [left, right] that lie
 * inside the subtree rooted at `node`.
 *
 * @param node  subtree root; NULL contributes 0
 * @param left  first index of the query (inclusive)
 * @param right last index of the query (inclusive)
 * @return the sum over the part of [left, right] covered by `node`;
 *         0 when that intersection is empty.
 *
 * The query is clamped to the node's own span first. Without the clamp, a
 * query that sticks out past the array never matches a node's range exactly
 * and elements that are in range get dropped from the result.
 */
int getRangeSum(NODE *node, int left, int right) {
    if (node == NULL) {
        return 0;
    }

    if (left < node->start) {
        left = node->start;
    }
    if (right > node->end) {
        right = node->end;
    }
    if (left > right) {
        return 0;
    }

    /* The query covers this node completely: use the cached sum. */
    if (left == node->start && right == node->end) {
        return node->sum;
    }

    /* The split point must come from the node's span, not from the query. */
    int mid = (node->start + node->end) / 2;
    if (right <= mid) {
        return getRangeSum(node->left, left, right);
    }
    if (left > mid) {
        return getRangeSum(node->right, left, right);
    }

    /* The query straddles the midpoint: split it between both children. */
    return getRangeSum(node->left, left, mid)
         + getRangeSum(node->right, mid + 1, right);
}

/* Returns the sum of the elements with indices in [left, right], inclusive. */
int numArraySumRange(NumArray* obj, int left, int right) {
    NODE *root = obj->root;
    int total = getRangeSum(root, left, right);
    return total;
}

/**
 * Frees the whole subtree rooted at `node`, including `node` itself.
 *
 * @param node subtree root; NULL is a no-op
 *
 * Children are released before their parent so no pointer is read after
 * the node holding it has been freed.
 */
void freeNode(NODE *node){
    if (node == NULL) {
        return;
    }
    freeNode(node->left);
    freeNode(node->right);
    free(node);
}

/**
 * Destroys a structure created by numArrayCreate().
 *
 * @param obj handle to destroy; NULL is a no-op
 *
 * Passes the tree to freeNode() and then frees the handle itself, which the
 * caller must not use afterwards.
 */
void numArrayFree(NumArray* obj) {
    if (obj == NULL) {
        return;
    }
    freeNode(obj->root);
    free(obj);
}

디버깅을 한참 했는데, getRangeSum의 `int mid = (node->start + node->end) / 2;` 부분을 node의 start, end가 아니라 left, right 값으로 계산하고 있었다. 바보… 요새 너무 올리는 게 없어서 이거라도 올리기…