def merge(start, end, prefix):
    """Merge the two ascending runs of ``prefix[start..end]`` in place.

    The inclusive range is split at ``mid = (start + end) // 2`` into
    ``prefix[start..mid]`` and ``prefix[mid+1..end]``, each of which must
    already be sorted ascending.  On return ``prefix[start..end]`` holds
    all of those elements in ascending order; slots outside the range are
    not touched.
    """
    mid = (start + end) // 2
    left, right = start, mid + 1
    merged = []
    # Repeatedly take the smaller head; on ties the left run wins.
    while left <= mid and right <= end:
        if prefix[left] <= prefix[right]:
            merged.append(prefix[left])
            left += 1
        else:
            merged.append(prefix[right])
            right += 1
    # At most one of the two runs still has elements; append its tail.
    merged.extend(prefix[left:mid + 1])
    merged.extend(prefix[right:end + 1])
    # Overwrite the original slots with the merged sequence.
    for offset, value in enumerate(merged):
        prefix[start + offset] = value
def merge_sort(start, end, prefix, lower, upper):
    """Count in-range values/differences over ``prefix[start..end]`` and sort it.

    The returned count is the sum of:

    * the number of indices ``t`` in ``[start, end]`` with
      ``lower <= prefix[t] <= upper`` (from the single-element base case), and
    * the number of index pairs ``start <= i < j <= end`` with
      ``lower <= prefix[j] - prefix[i] <= upper``.

    When ``prefix`` is a prefix-sum array with ``prefix[0] == 0`` and the
    call covers ``[1, n]``, the first term counts subarrays that begin at the
    first element and the second term counts all the others.

    Side effect: ``prefix[start..end]`` is left sorted ascending (merge sort
    via ``merge``).  Each level needs both halves sorted so that the two
    pointer scans below run in linear time.

    Args:
        start: First index of the inclusive range.
        end: Last index of the inclusive range.
        prefix: Mutable sequence of comparable numbers; reordered in place.
        lower: Inclusive lower bound of the accepted range.
        upper: Inclusive upper bound of the accepted range.

    Returns:
        The count described above.  An empty range (``start > end``) or an
        empty interval (``lower > upper``) yields 0.
    """
    if start > end:
        # Empty range: nothing to count.  Without this guard a call such as
        # (1, 0) computes mid == end and recurses on itself forever.
        return 0
    if start == end:
        # Single element: it counts when it lies in [lower, upper].
        return 1 if lower <= prefix[start] <= upper else 0

    mid = start + (end - start) // 2
    count = merge_sort(start, mid, prefix, lower, upper)
    count += merge_sort(mid + 1, end, prefix, lower, upper)

    # Both halves are now sorted.  For each left value x (ascending), count
    # right values y with x + lower <= y <= x + upper.  The bounds only grow
    # with x, so both right-half pointers move forward monotonically.
    hi = lo = mid + 1
    for i in range(start, mid + 1):
        # hi: first right index whose value exceeds x + upper.
        while hi <= end and prefix[hi] <= prefix[i] + upper:
            hi += 1
        # lo: first right index whose value reaches x + lower.
        while lo <= end and prefix[lo] < prefix[i] + lower:
            lo += 1
        # hi - lo is negative only when lower > upper (empty interval).
        count += max(0, hi - lo)

    merge(start, end, prefix)
    return count
def get_count(nums, lower, upper):
    """Return how many contiguous, non-empty subarrays of *nums* sum into [lower, upper].

    The sum of ``nums[i:j]`` equals ``prefix[j] - prefix[i]``, where
    ``prefix[t]`` is the sum of the first ``t`` elements.  The counting is
    delegated to ``merge_sort`` over the prefix sums.

    Args:
        nums: Sequence of numbers.
        lower: Inclusive lower bound on the subarray sum.
        upper: Inclusive upper bound on the subarray sum.

    Returns:
        The number of qualifying subarrays.  This is 0 for empty *nums*
        or when ``lower > upper``.
    """
    from itertools import accumulate

    n = len(nums)
    if n == 0 or lower > upper:
        # No subarrays exist, or no sum can fall in an empty interval.
        # Returning here also avoids calling merge_sort on the empty range [1, 0].
        return 0
    # prefix[0] = 0 and prefix[t] = nums[0] + ... + nums[t - 1].
    prefix = [0, *accumulate(nums)]
    # Index 0 is excluded from the sort.  merge_sort's base case compares each
    # prefix[t] with [lower, upper] directly, which is prefix[t] - prefix[0].
    return merge_sort(1, n, prefix, lower, upper)
def main():
    """Run a sample input and print the number of qualifying subarrays.

    For nums = [-2, 4, 1, -2] and range [-4, 1], the qualifying subarray
    sums are -2, 1, -2, -1 and 1, so this prints 5.
    """
    arr = [-2, 4, 1, -2]
    lower, upper = -4, 1
    print(get_count(arr, lower, upper))


if __name__ == "__main__":
    main()