USACO 2015 December Contest, Platinum Problem 3. Counting Haybales

(Analysis by Nick Wu)

For this problem, we are given a list of $N$ integers and want to support the following three operations:

1) Compute the minimum number in an interval over the list.

2) Compute the sum of all numbers in an interval over the list.

3) Increase every number in an interval of the list by some number $K$.

We start by describing a data structure which will support the first two operations in $O(\log N)$ time.

Given the initial list, we will maintain, for a specific set of intervals, the minimum number in that interval as well as the sum of all numbers in that interval. The list of intervals that we use will be as follows:

  • $[1,N]$ is in the set.
  • If $[a,b]$ is in the set with $a<b$, then letting $c=\lfloor\frac{a+b}{2}\rfloor$, both $[a,c]$ and $[c+1,b]$ are in the set.

Note that there are $O(N)$ intervals in this set (at most $2N-1$, since every interval that is not a single element splits into exactly two). Furthermore, for convenience, imagine that we are building a binary tree, where the root of the tree corresponds to the interval $[1,N]$ and its minimum/sum, the leaves of the tree are all intervals of the form $[i,i]$ with their minima/sums, and the left and right children of interval $[a,b]$ are $[a,c]$ and $[c+1,b]$ respectively.

Given the minimum for intervals $[a,c]$ and $[c+1,b]$, the minimum in the interval $[a,b]$ is just the minimum of those two. By similar logic, the sum for interval $[a,b]$ is the sum for $[a,c]$ plus the sum for $[c+1,b]$.
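The construction above can be sketched as a recursive build. This is only an illustration under some assumptions that are not in the text: the class name `StaticTreeBuild` is hypothetical, positions are 0-indexed, and the tree is stored in arrays where node 1 is the root and node `i` has children `2*i` and `2*i+1`.

```java
// Sketch: build the min/sum tree over a 0-indexed, non-empty array.
// Node 1 is the root; node i has children 2*i and 2*i+1 (an assumed layout).
class StaticTreeBuild {
    final long[] min;
    final long[] sum;
    final int n;

    StaticTreeBuild(long[] values) {
        n = values.length;
        min = new long[4 * n];   // 4*n slots are enough for this layout
        sum = new long[4 * n];
        build(1, 0, n - 1, values);
    }

    // Fill in node `index`, which covers values[l..r].
    private void build(int index, int l, int r, long[] values) {
        if (l == r) {            // leaf: the interval [l, l]
            min[index] = values[l];
            sum[index] = values[l];
            return;
        }
        int c = (l + r) / 2;     // split point, floor((l + r) / 2)
        build(2 * index, l, c, values);
        build(2 * index + 1, c + 1, r, values);
        // Combine the two children as described above.
        min[index] = Math.min(min[2 * index], min[2 * index + 1]);
        sum[index] = sum[2 * index] + sum[2 * index + 1];
    }

    public static void main(String[] args) {
        StaticTreeBuild t = new StaticTreeBuild(new long[] {5, 2, 7, 3});
        // Root covers the whole list: minimum 2, sum 17.
        System.out.println(t.min[1] + " " + t.sum[1]);
    }
}
```

Each node is filled in once, so the build takes $O(N)$ time in total.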

We will now describe in detail how to support the first operation - the second operation is supported with very similar logic.

Suppose we want the minimum number in an interval $[l,r]$. Consider how a single node of the tree can help us answer this query. If the node's interval is entirely contained in $[l,r]$, then the minimum stored for that node is a "candidate" for the answer and we don't need to inspect its children. If the node's interval is entirely outside $[l,r]$, then we ignore it. Otherwise, the node's interval partially overlaps $[l,r]$ and we need to recursively look at its two children. When we are done searching through the tree, we just return the smallest candidate.

If we start this process from the root, each level of the tree can contribute at most two candidates, and we recurse to the next level from at most two nodes. If three candidates could be found from a given level of the tree, then two of those candidates would have to have a common parent and we could have short-circuited the analysis by inspecting just that parent's interval. Similarly, the nodes where we would recurse to a lower level have to be directly to the left or the right of a valid candidate. Since the tree has $O(\log N)$ levels, we only inspect $O(\log N)$ nodes, as desired.

The second operation works in essentially the same way, except we sum the "candidates".
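As a sketch of both queries on a tree that does not yet support range updates, the three cases (outside, inside, partial overlap) map directly onto three branches. The class name `StaticRangeQuery`, the 0-indexed positions, and the array layout (root at node 1, children at `2*i` and `2*i+1`) are assumptions made for this illustration.

```java
// Sketch of the min/sum queries on a tree with no range updates yet.
// Names are illustrative; positions are 0-indexed and the list is non-empty.
class StaticRangeQuery {
    private final long[] min;
    private final long[] sum;
    private final int n;

    StaticRangeQuery(long[] values) {
        n = values.length;
        min = new long[4 * n];
        sum = new long[4 * n];
        build(1, 0, n - 1, values);
    }

    private void build(int index, int l, int r, long[] values) {
        if (l == r) {
            min[index] = sum[index] = values[l];
            return;
        }
        int c = (l + r) / 2;
        build(2 * index, l, c, values);
        build(2 * index + 1, c + 1, r, values);
        min[index] = Math.min(min[2 * index], min[2 * index + 1]);
        sum[index] = sum[2 * index] + sum[2 * index + 1];
    }

    long minQuery(int left, int right) { return minQuery(1, 0, n - 1, left, right); }
    long sumQuery(int left, int right) { return sumQuery(1, 0, n - 1, left, right); }

    private long minQuery(int index, int l, int r, int left, int right) {
        if (r < left || l > right) return Long.MAX_VALUE;  // outside: ignore
        if (left <= l && r <= right) return min[index];    // inside: a candidate
        int c = (l + r) / 2;                               // partial overlap: recurse
        return Math.min(minQuery(2 * index, l, c, left, right),
                        minQuery(2 * index + 1, c + 1, r, left, right));
    }

    private long sumQuery(int index, int l, int r, int left, int right) {
        if (r < left || l > right) return 0;               // outside: contributes nothing
        if (left <= l && r <= right) return sum[index];    // inside: add the stored sum
        int c = (l + r) / 2;
        return sumQuery(2 * index, l, c, left, right)
             + sumQuery(2 * index + 1, c + 1, r, left, right);
    }

    public static void main(String[] args) {
        StaticRangeQuery t = new StaticRangeQuery(new long[] {5, 2, 7, 3, 9});
        System.out.println(t.minQuery(2, 4)); // minimum of 7, 3, 9
        System.out.println(t.sumQuery(1, 3)); // 2 + 7 + 3
    }
}
```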

It remains to support the third operation. If we naively update this data structure by updating every relevant node, since there are $O(N)$ intervals, this would run in $O(N)$ time and be no better than the naive solution of maintaining just the list of elements.
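For reference, the naive solution of maintaining just the list can be sketched as follows. Every operation scans the interval, so each one costs $O(N)$ in the worst case. The class name `NaiveHaybales` and its method names are hypothetical; a baseline like this is mainly useful for cross-checking a faster structure on small inputs.

```java
// Naive baseline: keep only the list and scan the interval for every operation.
// Names are illustrative; positions are 0-indexed and intervals are inclusive.
class NaiveHaybales {
    private final long[] a;

    NaiveHaybales(long[] values) {
        a = values.clone();
    }

    // Operation 1: minimum over a[l..r].
    long min(int l, int r) {
        long best = Long.MAX_VALUE;
        for (int i = l; i <= r; i++) best = Math.min(best, a[i]);
        return best;
    }

    // Operation 2: sum over a[l..r].
    long sum(int l, int r) {
        long total = 0;
        for (int i = l; i <= r; i++) total += a[i];
        return total;
    }

    // Operation 3: add k to every element of a[l..r].
    void add(int l, int r, long k) {
        for (int i = l; i <= r; i++) a[i] += k;
    }

    public static void main(String[] args) {
        NaiveHaybales h = new NaiveHaybales(new long[] {5, 2, 7, 3});
        h.add(1, 2, 10);                  // list becomes 5, 12, 17, 3
        System.out.println(h.min(0, 2));  // minimum of 5, 12, 17
        System.out.println(h.sum(1, 3));  // 12 + 17 + 3
    }
}
```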

To get around this, we add another piece of information to the interval along with the minimum and the sum - the "increment", which is initialized to zero for every interval. We will maintain the invariant that if the increment of a given interval is $K$, then we need to increment every number in that interval by $K$. These increments stack with each other, so if interval $[1,2]$ has an increment of 2 and interval $[1,4]$ has an increment of 3, then the numbers in the interval $[1,2]$ actually need to be incremented by 5.
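To make the stacking concrete, here is a small sketch that adds up the increments stored on every tree interval containing a given position, walking from the root down to the leaf. The class name `PendingIncrement`, the `lazy` array, and the array layout (root at node 1, children at `2*i` and `2*i+1`) are assumptions for illustration. The code is 0-indexed, so the intervals $[1,4]$ and $[1,2]$ from the example become $[0,3]$ and $[0,1]$.

```java
// Sketch: total increment still owed to one position, found by summing the
// increments of every tree interval that contains it.
class PendingIncrement {
    // lazy[index] is the increment stored on node `index`; the tree covers 0..n-1.
    static long pending(long[] lazy, int n, int pos) {
        int index = 1, l = 0, r = n - 1;
        long total = lazy[index];          // the root contains every position
        while (l < r) {
            int c = (l + r) / 2;
            if (pos <= c) {                // descend into the left child [l, c]
                index = 2 * index;
                r = c;
            } else {                       // descend into the right child [c+1, r]
                index = 2 * index + 1;
                l = c + 1;
            }
            total += lazy[index];
        }
        return total;
    }

    public static void main(String[] args) {
        int n = 4;
        long[] lazy = new long[4 * n];
        lazy[1] = 3;   // increment of 3 on the root interval [0,3]
        lazy[2] = 2;   // increment of 2 on its left child [0,1]
        System.out.println(pending(lazy, n, 0)); // 3 + 2 = 5
        System.out.println(pending(lazy, n, 3)); // only the root applies: 3
    }
}
```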

To increment every number in a range $[l,r]$ by $K$, start the same process of walking from the root down to the leaves of the tree. If we look at an interval which is entirely outside $[l,r]$, then we can ignore it. If the interval partially overlaps with $[l,r]$, we have to recursively inspect the interval's children. If the interval is entirely contained in $[l,r]$, then increase its increment by $K$ and stop there. The interesting case is when the interval partially overlaps with $[l,r]$.

In this case, the given interval may have an increment that was requested earlier that needs to be applied to every number in the interval. In this situation, we will "increment" every number in that interval by incrementing every number in the two children by the given increment, and then set the increment in our current interval to zero. After that, we will recursively apply the current increment that was requested by recursing on the two children. Finally, to keep the data structure accurate, we will then look at the children to update the minimum and sum for the current interval.

The first two operations do not change significantly now that we must support the third operation - the major difference is that if we need to recursively inspect both children, we need to apply the increment to the two children and then update the current interval again.

By the same logic as in the analysis of the first operation, the third operation only inspects $O(\log N)$ nodes and therefore runs in $O(\log N)$ time.

This data structure is known as a range tree with lazy propagation (it is more commonly called a segment tree with lazy propagation). We call it a range tree because each node in the tree stores information about a range of the underlying list. The "lazy propagation" part comes in because when we request an update to the list (in this case, an increase over an interval of the list), we update as few nodes as possible while still keeping track of the need to update more nodes later.

Here is my implementation of this data structure.

import java.io.*;
import java.util.*;
public class haybales {
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new FileReader("haybales.in"));
		PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter("haybales.out")));
		StringTokenizer st = new StringTokenizer(br.readLine());
		int n = Integer.parseInt(st.nextToken());
		int q = Integer.parseInt(st.nextToken());
		RangeTree tree = new RangeTree(n);
		st = new StringTokenizer(br.readLine());
		// Load the initial list as n single-element increments.
		for(int i = 0; i < n; i++) {
			tree.update(i, i, Integer.parseInt(st.nextToken()));
		}
		for(int i = 0; i < q; i++) {
			st = new StringTokenizer(br.readLine());
			String type = st.nextToken();
			// Convert the 1-indexed endpoints to 0-indexed.
			int a = Integer.parseInt(st.nextToken())-1;
			int b = Integer.parseInt(st.nextToken())-1;
			int c = -1;
			if(type.equals("P")) {
				c = Integer.parseInt(st.nextToken());
			}
			if(type.equals("M")) {
				pw.println(tree.minQuery(a, b));
			}
			else if(type.equals("P")) {
				tree.update(a, b, c);
			}
			else {
				pw.println(tree.sumQuery(a, b));
			}
		}
		pw.close();
	}
	static class RangeTree {
		// Node 1 is the root; node i has children 2*i and 2*i+1.
		// min[i] and sum[i] exclude the pending increment lazy[i] of node i itself.
		private long[] min;
		private long[] lazy;
		private long[] sum;
		private int size;
		public RangeTree(int size) {
			this.size = size;
			min = new long[4*size];
			sum = new long[4*size];
			lazy = new long[4*size];
		}
		public void update(int l, int r, int inc) {
			update(1, 0, size-1, l, r, inc);
		}
		// Apply this node's pending increment to its own min/sum and hand it to the children.
		private void pushDown(int index, int l, int r) {
			min[index] += lazy[index];
			sum[index] += lazy[index] * (r-l+1);
			if(l != r) {
				lazy[2*index] += lazy[index];
				lazy[2*index+1] += lazy[index];
			}
			lazy[index] = 0;
		}
		// Recompute this node's min/sum from its children, including their pending increments.
		private void pullUp(int index, int l, int r) {
			int m = (l+r)/2;
			min[index] = Math.min(evaluateMin(2*index, l, m), evaluateMin(2*index+1, m+1, r));
			sum[index] = evaluateSum(2*index, l, m) + evaluateSum(2*index+1, m+1, r);
		}
		// True sum of the node's interval, accounting for its pending increment.
		private long evaluateSum(int index, int l, int r) {
			return sum[index] + (r-l+1)*lazy[index];
		}
		// True minimum of the node's interval, accounting for its pending increment.
		private long evaluateMin(int index, int l, int r) {
			return min[index] + lazy[index];
		}
		private void update(int index, int l, int r, int left, int right, int inc) {
			// Node interval is entirely outside the update range.
			if(r < left || l > right) return;
			// Node interval is entirely inside the update range: record the increment and stop.
			if(l >= left && r <= right) {
				lazy[index] += inc;
				return;
			}
			// Partial overlap: push the pending increment down, recurse, then rebuild this node.
			pushDown(index, l, r);
			int m = (l+r)/2;
			update(2*index, l, m, left, right, inc);
			update(2*index+1, m+1, r, left, right, inc);
			pullUp(index, l, r);
		}
		public long minQuery(int l, int r) {
			return minQuery(1, 0, size-1, l, r);
		}
		private long minQuery(int index, int l, int r, int left, int right) {
			if(r < left || l > right) return Long.MAX_VALUE;
			if(l >= left && r <= right) {
				return evaluateMin(index, l, r);
			}
			pushDown(index, l, r);
			int m = (l+r)/2;
			long ret = Long.MAX_VALUE;
			ret = Math.min(ret, minQuery(2*index, l, m, left, right));
			ret = Math.min(ret, minQuery(2*index+1, m+1, r, left, right));
			pullUp(index, l, r);
			return ret;
		}
		public long sumQuery(int l, int r) {
			return sumQuery(1, 0, size-1, l, r);
		}
		private long sumQuery(int index, int l, int r, int left, int right) {
			if(r < left || l > right) return 0;
			if(l >= left && r <= right) {
				return evaluateSum(index, l, r);
			}
			pushDown(index, l, r);
			int m = (l+r)/2;
			long ret = 0;
			ret += sumQuery(2*index, l, m, left, right);
			ret += sumQuery(2*index+1, m+1, r, left, right);
			pullUp(index, l, r);
			return ret;
		}
	}
}