
Train DQN Agent to Balance Cart-Pole System

This example shows how to train a deep Q-learning network (DQN) agent to balance a cart-pole system modeled in MATLAB®.

For more information on DQN agents, see Deep Q-Network Agents. For an example that trains a DQN agent in Simulink®, see Train DQN Agent to Swing Up and Balance Pendulum.

Cart-Pole MATLAB Environment

The reinforcement learning environment for this example is a pole attached to an unactuated joint on a cart, which moves along a frictionless track. The training goal is to make the pole stand upright without falling over.

For this environment:

  • The upward balanced pole position is 0 radians, and the downward hanging position is pi radians.

  • The pole starts upright with an initial angle between –0.05 and 0.05 radians.

  • The force action signal from the agent to the environment is from –10 to 10 N.

  • The observations from the environment are the position and velocity of the cart, the pole angle, and the pole angle derivative.

  • The episode terminates if the pole is more than 12 degrees from vertical or if the cart moves more than 2.4 m from the original position.

  • A reward of +1 is provided for every time step that the pole remains upright. A penalty of –5 is applied when the pole falls.

For more information on this model, see Load Predefined Control System Environments.

Create Environment Interface

Create a predefined environment interface for the system.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

The interface has a discrete action space where the agent can apply one of two possible force values to the cart, –10 or 10 N.

Get the observation and action specification information.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"
actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"
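
The specifications match the environment description above: a four-element continuous observation and a discrete action set containing the two force values. As an optional sanity check, you can interact with the environment directly by resetting it and stepping it once. The following is a minimal sketch assuming the reset and step functions supported by predefined MATLAB environments; the +10 N force is chosen only for illustration.

% Optional check (illustrative): reset the environment and apply a +10 N
% force for one time step.
initialObs = reset(env);              % 4x1 state: [x; dx; theta; dtheta]
[nextObs,reward,isDone] = step(env,10);
reward                                % +1 while the pole remains upright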

Fix the random generator seed for reproducibility.

rng(0)

Create DQN Agent

A DQN agent approximates the long-term reward, given observations and actions, using a value-function critic.

DQN agents can use multi-output Q-value critic approximators, which are generally more efficient than comparable single-output approximators. A multi-output approximator has observations as inputs and state-action values as outputs. Each output element represents the expected cumulative long-term reward for taking the corresponding discrete action from the state indicated by the observation inputs.

To create the critic, first create a deep neural network with one input (the 4-dimensional observed state) and one output vector with two elements (one for the 10 N action, another for the –10 N action). For more information on creating value-function representations based on a neural network, see Create Policy and Value Function Representations.

dnn = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];

View the network configuration.

figure
plot(layerGraph(dnn))

(Figure: layer graph of the critic network.)

Specify some training options for the critic representation using rlRepresentationOptions.

criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);

Create the critic representation using the specified neural network and options. For more information, see rlQValueRepresentation.

critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);
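
Optionally, you can query the untrained critic for a sample observation to confirm that it returns one Q-value per discrete action. This is a minimal sketch assuming the getValue function; the zero state is an arbitrary input used only for illustration.

% Illustrative check of the multi-output critic (untrained weights):
% one Q-value is returned for each element of the action set, [-10 10] N.
sampleObs = {zeros(obsInfo.Dimension)};
qValues = getValue(critic,sampleObs)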

To create the DQN agent, first specify the DQN agent options using rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);

Then, create the DQN agent using the specified critic representation and agent options. For more information, see rlDQNAgent.

agent = rlDQNAgent(critic,agentOpts);
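
As a quick check before training, you can ask the agent for an action given an arbitrary observation. This sketch assumes the getAction function with a zero state chosen only for illustration; depending on the toolbox release, the returned action may be wrapped in a cell array.

% Illustrative check: the agent should return one of the two allowed
% force values (-10 or 10 N) for any observation.
act = getAction(agent,{zeros(obsInfo.Dimension)})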

Train Agent

To train the agent, first specify the training options. For this example, use the following options:

  • Run one training session containing at most 1000 episodes, with each episode lasting at most 500 time steps.

  • Display the training progress in the Episode Manager dialog box (set the Plots option) and disable the command-line display (set the Verbose option to false).

  • Stop training when the agent receives a moving average cumulative reward greater than 480. At that point, the agent can balance the cart-pole system in the upright position.

For more information, see rlTrainingOptions.

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',480);

You can visualize the cart-pole system by using the plot function during training or simulation.

plot(env)

(Figure: Cart Pole Visualizer showing the cart-pole system.)

Train the agent using the train function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('MATLABCartpoleDQNMulti.mat','agent')
end
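
If you train the agent yourself, you can optionally save it for later reuse, for example as sketched below; the file name here is arbitrary.

% Optional: save the trained agent and training statistics
% (illustrative file name).
if doTraining
    save('trainedCartPoleDQNAgent.mat','agent','trainingStats')
end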

Simulate DQN Agent

To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim. The agent can balance the cart-pole system even when the simulation time increases to 500 steps.

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

(Figure: Cart Pole Visualizer showing the cart-pole system during simulation.)

totalReward = sum(experience.Reward)
totalReward = 500
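
A total reward of 500 indicates that the pole remained upright for all 500 simulation steps, earning +1 per step with no falling penalty. To inspect the reward over time, you can plot the per-step reward from the returned experience structure; this minimal sketch assumes the Reward field returned by sim is a timeseries.

% Illustrative: plot the per-step reward logged during simulation.
figure
plot(experience.Reward)
title('Per-step reward during simulation')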
