REINFORCE with Baseline

In my last post, I implemented REINFORCE, a simple policy gradient algorithm. We saw that while the agent did learn, the high variance of the returns held the learning back. In this post, I will discuss a technique that helps reduce that variance.

Proof

Consider the set of numbers 500, 50, and 250. The variance of this set of numbers is about 50,833. What if we subtracted some value from each number, say 400, 30, and 200? The new set of numbers would be 100, 20, and 50, and the variance would be about 1,633. This is a significant reduction, and the same idea can be applied to our policy gradient algorithm: we can reduce the variance by subtracting some baseline value from the returns. But wouldn't subtracting an arbitrary value from the returns bias the gradient estimate? It turns out that the answer is no, and below is the proof.
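
Before getting to the proof, here is a quick check of the arithmetic above (a sketch of my own, not part of the algorithm). MATLAB's var function computes the sample variance and reproduces the numbers quoted:

% Verify the variances quoted above. var computes the sample variance.
returns = [500 50 250];
baseline = [400 30 200];
var(returns)             % about 50,833
var(returns - baseline)  % about 1,633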

Recall the policy gradient theorem:

$$\nabla_\theta J\left(\pi_\theta\right) = \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) \sum_{t' = t}^T \gamma^{t'} r_{t'}\right]$$

Suppose we subtract from the return some value, $b\left(s_t\right)$, that is a function of the current state, $s_t$, so that we now have

$$\begin{aligned} \nabla_\theta J\left(\pi_\theta\right) &= \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) \left(\sum_{t' = t}^T \gamma^{t'} r_{t'} - b\left(s_t\right)\right) \right] \\
&= \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) \sum_{t' = t}^T \gamma^{t'} r_{t'} - \sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) b\left(s_t\right) \right] \\
&= \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) \sum_{t' = t}^T \gamma^{t'} r_{t'} \right] - \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) b\left(s_t\right) \right] \end{aligned}$$

We can also expand the second expectation term as

$$\begin{aligned} \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) b\left(s_t\right) \right] &= \mathbb{E} \left[\nabla_\theta \log \pi_\theta \left(a_0 \vert s_0 \right) b\left(s_0\right) + \nabla_\theta \log \pi_\theta \left(a_1 \vert s_1 \right) b\left(s_1\right) + \cdots + \nabla_\theta \log \pi_\theta \left(a_T \vert s_T \right) b\left(s_T\right)\right] \\
&= \mathbb{E} \left[\nabla_\theta \log \pi_\theta \left(a_0 \vert s_0 \right) b\left(s_0\right)\right] + \mathbb{E} \left[\nabla_\theta \log \pi_\theta \left(a_1 \vert s_1 \right) b\left(s_1\right)\right] + \cdots + \mathbb{E} \left[\nabla_\theta \log \pi_\theta \left(a_T \vert s_T \right) b\left(s_T\right)\right] \end{aligned}$$

Because the probability of each state-action pair occurring under the current policy is assumed not to change with time, all of the expectations are the same, and we can reduce the expression to

$$\mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) b\left(s_t\right) \right] = \left(T + 1\right) \mathbb{E} \left[\nabla_\theta \log \pi_\theta \left(a_0 \vert s_0 \right) b\left(s_0\right)\right]$$

I apologize in advance to all the researchers I may have disrespected with any blatantly wrong math up to this point. I am just a lowly mechanical engineer (on paper, not sure what I am in practice). Please correct me in the comments if you see any mistakes. But assuming no mistakes, we will continue. Using the definition of expectation, we can rewrite the expectation term on the RHS as

$$\begin{aligned} \mathbb{E} \left[\nabla_\theta \log \pi_\theta \left(a_0 \vert s_0 \right) b\left(s_0\right) \right] &= \sum_s \mu\left(s\right) \sum_a \pi_\theta \left(a \vert s\right) \nabla_\theta \log \pi_\theta \left(a \vert s \right) b\left(s\right) \\
&= \sum_s \mu\left(s\right) \sum_a \pi_\theta \left(a \vert s\right) \frac{\nabla_\theta \pi_\theta \left(a \vert s \right)}{\pi_\theta \left(a \vert s\right)} b\left(s\right) \\
&= \sum_s \mu\left(s\right) b\left(s\right) \sum_a \nabla_\theta \pi_\theta \left(a \vert s \right) \\
&= \sum_s \mu\left(s\right) b\left(s\right) \nabla_\theta \sum_a \pi_\theta \left(a \vert s \right) \\
&= \sum_s \mu\left(s\right) b\left(s\right) \nabla_\theta 1 \\
&= \sum_s \mu\left(s\right) b\left(s\right) \left(0\right) \\
&= 0 \end{aligned}$$

where $\mu\left(s\right)$ is the probability of being in state $s$. Therefore,

$$\mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) b\left(s_t\right) \right] = 0$$

and

$$\begin{aligned} \nabla_\theta J\left(\pi_\theta\right) &= \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) \left(\sum_{t' = t}^T \gamma^{t'} r_{t'} - b\left(s_t\right)\right) \right] \\
&= \mathbb{E} \left[\sum_{t = 0}^T \nabla_\theta \log \pi_\theta \left(a_t \vert s_t \right) \sum_{t' = t}^T \gamma^{t'} r_{t'} \right] \end{aligned}$$

In other words, as long as the baseline we subtract from the return is independent of the action, it has no effect on the expected gradient! But what is $b\left(s_t\right)$? It can be anything, even a constant, as long as it does not depend on the action. We will choose it to be $\hat{V}\left(s_t,w\right)$, the estimate of the value function at the current state, where $w$ is the vector of weights parametrizing $\hat{V}$. I think Sutton & Barto do a good job explaining the intuition behind this choice. Some states yield higher returns and others yield lower returns, and the value function is a good baseline because it adjusts accordingly based on the state. A state that yields a higher return will also have a higher value estimate, so we subtract a higher baseline. Likewise, we subtract a lower baseline for states with lower returns.
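
As a quick numerical sanity check (my own sketch, not part of the derivation above), we can evaluate the expectation over actions of the grad-log-probability times an arbitrary baseline for the two-action logistic policy used in the implementation below. It should come out to (numerically) zero for any state, parameter vector, and baseline value:

% Sanity check: the action-expectation of grad-log-prob times a baseline is zero.
% The state, parameters, and baseline value below are arbitrary choices.
theta = [0.1; -0.2; 0.3; 0.05];      % arbitrary policy parameters
s = [0.02; -0.1; 0.03; 0.2];         % arbitrary state
b = 7;                               % arbitrary action-independent baseline value
pLeft = 1 / (1 + exp(-theta' * s));  % probability of the "left" action
% Grad-log-probabilities of the two actions under the logistic policy.
gradLogLeft = s * (1 - pLeft);
gradLogRight = -s * pLeft;
% Expectation over actions; should be a vector of (numerical) zeros.
expectation = pLeft * gradLogLeft * b + (1 - pLeft) * gradLogRight * b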

But we also need a way to approximate $\hat{V}$. We can update the parameters of $\hat{V}$ using stochastic gradient descent. Once we have sampled a trajectory, we know the return observed from each state, so we can calculate the error between the observed return and the estimated value as

$$\delta = G_t - \hat{V} \left(s_t,w\right)$$

If we square this and calculate the gradient, we get

$$\begin{aligned} \nabla_w \left[ \frac{1}{2} \left(G_t - \hat{V} \left(s_t,w\right) \right)^2\right] &= -\left(G_t - \hat{V} \left(s_t,w\right) \right) \nabla_w \hat{V} \left(s_t,w\right) \\
&= -\delta \nabla_w \hat{V} \left(s_t,w\right) \end{aligned}$$

I included the $\frac{1}{2}$ just to keep the math clean. We want to minimize this error, so we update the parameters using gradient descent:

$$w = w + \delta \nabla_w \hat{V} \left(s_t,w\right)$$
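
As a quick check of the gradient above (again a sketch of my own, not from the post), we can compare the analytic expression $-\delta \nabla_w \hat{V}\left(s_t,w\right)$ against a finite-difference approximation. For the check I assume a linear value estimate $\hat{V}\left(s,w\right) = w^T s$, which is also the form used in the implementation below; the state, weights, and return are arbitrary example values:

% Finite-difference check of the gradient of the squared error, assuming a
% linear value estimate V_hat(s,w) = w' * s. The numbers are arbitrary.
s = [0.05; -0.02; 0.01; 0.1];
w = [0.3; -0.1; 0.2; 0];
Gt = 10;
loss = @(w) 0.5 * (Gt - w' * s)^2;
analyticGrad = -(Gt - w' * s) * s;    % -delta * grad of V_hat, with grad V_hat = s
numericGrad = zeros(4,1);
h = 1e-6;
for i = 1:4
    e = zeros(4,1);
    e(i) = h;
    numericGrad(i) = (loss(w + e) - loss(w - e)) / (2 * h);
end
max(abs(analyticGrad - numericGrad))  % should be (numerically) zero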

Now, we will implement this to help make things more concrete.

Implementation

As in my previous posts, I will test the algorithm on the discrete cart-pole environment. The REINFORCE algorithm with baseline is mostly the same as the one from my last post, with the addition of the value function estimate and the baseline subtraction. In my implementation, I used a linear function approximation so that

$$\hat{V} \left(s_t,w\right) = w^T s_t$$

where $w$ and $s_t$ are $4 \times 1$ column vectors. Then,

$$\nabla_w \hat{V} \left(s_t,w\right) = s_t$$

and we update the parameters according to

$$w = w + \left(G_t - w^T s_t\right) s_t$$
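
To make the update concrete, here is a minimal worked example (arbitrary numbers of my own) of a single update of the linear value-function weights:

% One update of the linear value-function weights with arbitrary example values.
w = zeros(4,1);                 % weights start at zero, as in the code below
s = [0.05; -0.02; 0.01; 0.1];   % an example 4x1 state vector
Gt = 20;                        % an example observed return
w = w + (Gt - w' * s) * s;      % delta = Gt - w'*s is positive, so w'*s increases
w' * s                          % the new estimate has moved toward Gt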

Below is the complete code:

% REINFORCE with Baseline applied to the discrete cart-pole environment
close all; clear; clc

% Seed the random number generator for reproducibility. I will be
% initializing the weights to zero, so it does not matter much in this
% case, but I included it for completeness.
rng(0)

% Create the environment.
env = rlPredefinedEnv('CartPole-Discrete');

% Get the dimension of the observations, in this case, 4. This is not to be
% confused with the number of observations per time step, which is 1.
dimObs = env.getObservationInfo.Dimension(1);

% Get the dimension of the actions, in this case, 1. This is not to be
% confused with the number of actions to choose from, which is 2.
dimAct = env.getActionInfo.Dimension(1);

% Initialize the policy and value function parameters to zero. They can 
% also be initialized randomly from a normal distribution, or whatever else 
% you would like, but I like zero.
theta_policy = zeros(dimObs,1);
theta_valueFcn = zeros(dimObs,1);

% Define the policy. This takes a 4x1 observation vector as input and
% outputs the probability of picking the "left" action. The "right" action 
% is 1 - probability(left), so we do not need to define it explicitly.
pi_theta_left = @(theta,s) 1 / (1 + exp(-theta' * s(:)));

% Define the gradients of the log of the policy. This outputs a 2x1 cell
% array of the grad-log-probabilities of the actions.
grad_pi_theta = @(theta,s) {s * (1 - pi_theta_left(theta,s)); -s * pi_theta_left(theta,s)};

% Set the maximum number of episodes for which to train.
maxEpisodes = 1000;

% Set the maximum number of time steps per episode. When the agent gets
% good at balancing the pole, it may never fail, so we need a way to
% terminate the episode after a certain time.
maxStepsPerEpisode = 500;

% Set the learning rates.
learningRate_policy = 0.001;
learningRate_valueFcn = 0.1;

% Set the discount factor.
discountFactor = 0.99;

% Initialize a vector to store the total time steps that the agent lasts
% per episode.
totalStepsPerEpisode = zeros(1,maxEpisodes);

% Train the agent.
for episode = 1:maxEpisodes
    % Reset the environment.
    observation = env.reset;
    
    % Initialize a variable to store the trajectory of the current episode.
    trajectory = cell(1,maxStepsPerEpisode);
    
    % Simulate an episode.
    for stepCt = 1:maxStepsPerEpisode
        % Sample an action.
        action = randsample(env.getActionInfo.Elements,1,true,...
            [pi_theta_left(theta_policy,observation) 1 - pi_theta_left(theta_policy,observation)]);
        
        % Apply the action and get the next observation, reward, and
        % whether the episode is over.
        [nextObservation,reward,isDone] = step(env,action);
        
        % Store the experience in the trajectory.
        trajectory{stepCt} = {observation action reward};
        
        % Update the environment.
        observation = nextObservation;
        
        if isDone
            break
        end
    end
    
    % Update the list of total steps per episode.
    totalStepsPerEpisode(episode) = stepCt;
    
    % Initialize the gradient estimates.
    gradientEstimate_policy = 0;
    gradientEstimate_valueFcn = 0;
    
    % Calculate the gradient estimates.
    for t = 1:stepCt
        % Get the observation and action at the current time step.
        observation = trajectory{t}{1};
        action = trajectory{t}{2};
        
        % Compute the discounted return from the current time step onward,
        % using the rewards stored in the trajectory.
        discountedReturn = 0;
        for k = t:stepCt
            discountedReturn = discountedReturn + discountFactor^(k - t) * trajectory{k}{3};
        end
        
        % Compute the grad-log-probabilities.
        gradLogProbs = grad_pi_theta(theta_policy,observation);
        
        % Get the grad-log-probability corresponding to the action
        % taken.
        gradLogProb = gradLogProbs{action == env.getActionInfo.Elements};
        
        % Compute delta.
        delta = discountedReturn - theta_valueFcn' * observation;
        
        % Update the gradient estimates.
        gradientEstimate_policy = gradientEstimate_policy + gradLogProb * delta;
        gradientEstimate_valueFcn = gradientEstimate_valueFcn + delta * observation;
    end
    
    % Update the policy and value function parameters.
    theta_policy = theta_policy + learningRate_policy * gradientEstimate_policy;
    theta_valueFcn = theta_valueFcn + learningRate_valueFcn * gradientEstimate_valueFcn / stepCt;
end

% Plot Total Steps per Episode.
plot(1:episode,totalStepsPerEpisode(1:episode))
xlabel('Episode')
ylabel('Total Steps')
title('Total Steps per Episode')
grid on
grid minor

Note that I update both the policy and value function parameters once per trajectory. The policy gradient estimate needs the entire trajectory before it can be computed, whereas a value function gradient estimate can be computed from a single time step. As a result, each trajectory gives multiple gradient estimates of the value function, which I average together before updating the value function parameters. I do not think the averaging is mandatory, though; the division by stepCt could just as well be absorbed into the learning rate.

Also note that I set the learning rate for the value function parameters to be much higher than that of the policy parameters. My intuition for this is that we want the value function to be learned faster than the policy so that the policy can be updated more accurately. But this is just speculation and with some trial and error, a lower learning rate for the value function parameters might be more effective.

Results

Below are the training results:

For comparison, here are the results without subtracting the baseline:

We can see that there is definitely an improvement in the variance when subtracting a baseline. But in terms of which training curve is actually better, I am not too sure. The unfortunate thing with reinforcement learning is that, at least in my case, even when implemented incorrectly, the algorithm may seem to work, sometimes even better than when implemented correctly. So I am not sure if the above results are accurate, or if there is some subtle mistake that I made. Please let me know in the comments if you find any bugs.

Of course, there is always room for improvement. In my next post, we will discuss how to update the policy without having to sample an entire trajectory first. This will allow us to update the policy during the episode rather than after it, which should speed up training.