Create Custom Reinforcement Learning Agents

To implement your own custom reinforcement learning algorithms, you can create a custom agent by creating a subclass of the rl.agent.CustomAgent class. You can then train and simulate this agent in MATLAB® and Simulink® environments. For more information about creating classes in MATLAB, see User-Defined Classes.

Create Template Class

To define your custom agent, first create a class that is a subclass of the rl.agent.CustomAgent class. As an example, this topic describes the custom LQR agent trained in Train Custom LQR Agent. As a starting point for your own agent, you can open and modify this custom agent class. To add the example files to the MATLAB path and open the file, at the MATLAB command line, type the following code.

addpath(fullfile(matlabroot,'examples','rl','main'));
edit LQRCustomAgent.m

After saving the class to your own working folder, you can remove the example files from the path.

rmpath(fullfile(matlabroot,'examples','rl','main'));

This class has the following class definition, which indicates the agent class name and the abstract agent class from which it inherits.

classdef LQRCustomAgent < rl.agent.CustomAgent

To define your agent, you must specify the following:

  • Agent properties

  • Constructor function

  • A critic that estimates the discounted long-term reward (if required for learning)

  • An actor that selects an action based on the current observation (if required for learning)

  • Required agent methods

  • Optional agent methods

Agent Properties

In the properties section of the class file, specify any parameters necessary for creating and training the agent. These parameters can include:

  • Discount factor for discounting future rewards

  • Configuration parameters for exploration models, such as noise models or epsilon-greedy exploration

  • Experience buffers for using replay memory

  • Mini-batch sizes for sampling from the experience buffer

  • Number of steps to look ahead during training

For more information on potential agent properties, see the option objects for the built-in Reinforcement Learning Toolbox™ agents.

The rl.agent.CustomAgent class already includes properties for the agent sample time (SampleTime) and the action and observation specifications (ActionInfo and ObservationInfo, respectively).

The custom LQR agent defines the following agent properties.

properties
    % Q
    Q

    % R
    R

    % Feedback gain
    K

    % Discount factor
    Gamma = 0.95

    % Critic
    Critic

    % Buffer for K
    KBuffer
    % Number of updates for K
    KUpdate = 1

    % Number for estimator update
    EstimateNum = 10
end

properties (Access = private)
    Counter = 1
    YBuffer
    HBuffer
end

Constructor Function

To create your custom agent, you must define a constructor function that:

  • Defines the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.

  • Creates the actor and critic as required by your training algorithm. For more information, see Create Policies and Value Functions.

  • Configures agent properties.

  • Calls the constructor of the base abstract class.

For example, the LQRCustomAgent constructor defines continuous action and observation spaces and creates a critic. The createCritic function is an optional helper function that defines the critic.

function obj = LQRCustomAgent(Q,R,InitialK)
    % Check the number of input arguments
    narginchk(3,3);

    % Call the abstract class constructor
    obj = obj@rl.agent.CustomAgent();

    % Set the Q and R matrices
    obj.Q = Q;
    obj.R = R;

    % Define the observation and action spaces
    obj.ObservationInfo = rlNumericSpec([size(Q,1),1]);
    obj.ActionInfo = rlNumericSpec([size(R,1),1]);

    % Create the critic
    obj.Critic = createCritic(obj);

    % Initialize the gain matrix
    obj.K = InitialK;

    % Initialize the experience buffers
    obj.YBuffer = zeros(obj.EstimateNum,1);
    num = size(Q,1) + size(R,1);
    obj.HBuffer = zeros(obj.EstimateNum,0.5*num*(num+1));
    obj.KBuffer = cell(1,1000);
    obj.KBuffer{1} = obj.K;
end
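To see how the constructor sizes its buffers, the following base-MATLAB sketch repeats the same arithmetic outside the class, using the Q and R values from Create Custom Agent and the default EstimateNum value of 10. The variable names mirror the class properties, but here they are plain workspace variables; this is an illustration, not part of LQRCustomAgent.

```matlab
% Same sizing arithmetic as the LQRCustomAgent constructor, run outside the class
Q = [10,3,1;3,5,4;1,4,9];   % state weighting (values from Create Custom Agent)
R = 0.5*eye(3);             % action weighting (values from Create Custom Agent)
EstimateNum = 10;           % default value of the EstimateNum property

num = size(Q,1) + size(R,1);             % length of the stacked [x;u] vector
numBasis = 0.5*num*(num+1);              % number of quadratic basis terms
YBuffer = zeros(EstimateNum,1);          % one target value per buffered experience
HBuffer = zeros(EstimateNum,numBasis);   % one basis-difference row per buffered experience
KBuffer = cell(1,1000);                  % history of feedback gain matrices
```

Each HBuffer row holds one value per quadratic basis term (21 terms for a 6-element stacked observation-action vector), which matches the length of the initial critic weights w0 in createCritic. Because learnImpl solves a least-squares problem over the EstimateNum buffered rows, that problem has a unique solution only when EstimateNum is at least the number of basis terms, so for this problem size you would likely need a value larger than the default of 10.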

Actor and Critic

If your learning algorithm uses a critic to estimate the long-term reward, an actor for selecting an action, or both, you must add these as agent properties. You must then create these objects when you create your agent; that is, in the constructor function. For more information on creating actors and critics, see Create Policies and Value Functions.

For example, the custom LQR agent uses a critic, stored in its Critic property, and no actor. The critic creation is implemented in the createCritic helper function, which is called from the LQRCustomAgent constructor.

function critic = createCritic(obj)
    nQ = size(obj.Q,1);
    nR = size(obj.R,1);
    n = nQ+nR;
    w0 = 0.1*ones(0.5*(n+1)*n,1);
    critic = rlQValueFunction({@(x,u) computeQuadraticBasis(x,u,n),w0},...
        getObservationInfo(obj),getActionInfo(obj));
    critic.Options.GradientThreshold = 1;
end

In this case, the critic is an rlQValueFunction object. To create this object, you must specify the handle to a custom basis function, in this case the computeQuadraticBasis function. For more information, see Train Custom LQR Agent.
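The computeQuadraticBasis function ships with the Train Custom LQR Agent example and is not listed here. As a rough illustration of what a quadratic basis of this size can look like, the following sketch stacks an example observation and action into a vector z and keeps each pairwise product z(i)*z(j) exactly once. The example values and the term ordering (upper triangle, column by column) are assumptions and may differ from the example's own implementation.

```matlab
% Hypothetical quadratic basis; NOT the example's computeQuadraticBasis function
x = [1; 2];                      % example observation (assumed values, nQ = 2)
u = 3;                           % example action (assumed value, nR = 1)
z = [x; u];                      % stacked observation-action vector
n = numel(z);                    % n = nQ + nR
M = z*z';                        % every pairwise product z(i)*z(j)
mask = logical(triu(ones(n)));   % keep each product once (upper triangle incl. diagonal)
phi = M(mask);                   % column vector with n*(n+1)/2 entries

% A linear-in-parameters critic evaluates to w'*phi
w0 = 0.1*ones(n*(n+1)/2,1);      % same initialization pattern as createCritic
qValue = w0'*phi;                % scalar critic estimate for the pair (x,u)
```

Because every quadratic term appears once, a weight vector of length n*(n+1)/2 can represent any quadratic form in [x;u], which is why createCritic initializes w0 with 0.5*(n+1)*n elements.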

Required Functions

To create a custom reinforcement learning agent, you must define the following implementation functions. To call these functions in your own code, use the wrapper methods from the abstract base class. For example, to call getActionImpl, use getAction. The wrapper methods have the same input and output arguments as the implementation methods.

Function Description
getActionImpl Selects an action by evaluating the agent policy for a given observation
getActionWithExplorationImpl Selects an action using the exploration model of the agent
learnImpl Learns from the current experiences and returns an action with exploration

Within your implementation functions, to evaluate your actor and critic, you can use the getValue, getAction, and getMaxQValue functions.

  • To evaluate an rlValueFunction critic, you need only the observation input, and you can obtain the value of the current observation V using the following syntax.

    V = getValue(Critic,Observation);
  • To evaluate an rlQValueFunction critic, you need both observation and action inputs, and you can obtain the value of the current state-action pair Q using the following syntax.

    Q = getValue(Critic,Observation,Action);
  • To evaluate an rlVectorQValueFunction critic, you need only the observation input, and you can obtain the values Q of the current observation for all possible discrete actions using the following syntax.

    Q = getValue(Critic,Observation);
  • For an rlQValueFunction critic with a discrete action space, obtain the maximum Q value MaxQ over all possible discrete actions, and the index of the corresponding action, using the following syntax.

    [MaxQ,MaxActionIndex] = getMaxQValue(Critic,Observation);
  • To evaluate an actor, obtain the action using the following syntax.

    A = getAction(Actor,Observation);

For each of these cases, if your actor or critic network uses a recurrent neural network, the functions can also return the current values of the network state after obtaining the corresponding network output.


The getActionImpl function evaluates the policy of your agent and selects an action. This function must have the following signature, where obj is the agent object, Observation is the current observation, and action is the selected action.

function action = getActionImpl(obj,Observation)

For the custom LQR agent, you select an action by applying the u = -Kx control law.

function action = getActionImpl(obj,Observation)
    % Given the current state of the system, return an action
    action = -obj.K*Observation{:};
end
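The following base-MATLAB sketch evaluates the same expression outside the class to show why getActionImpl indexes the observation with {:}: observations are passed as a cell array with one cell per observation channel, and Observation{:} unpacks the single channel into a numeric column vector. The gain K and the observation values are hypothetical.

```matlab
% Hypothetical gain and observation, used only to show the u = -Kx expression
K = [1 0 0; 0 2 0; 0 0 3];   % example 3-by-3 feedback gain (assumed values)
Observation = {[1; 1; 1]};   % single-channel observation stored in a cell array

% Observation{:} unpacks the only channel into a numeric column vector
action = -K*Observation{:};
```

With a single observation channel, Observation{:} is equivalent to Observation{1}.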


The getActionWithExplorationImpl function selects an action using the exploration model of your agent. Using this function, you can implement algorithms such as epsilon-greedy exploration. This function must have the following signature, where obj is the agent object, Observation is the current observation, and action is the selected action.

function action = getActionWithExplorationImpl(obj,Observation)

For the custom LQR agent, the getActionWithExplorationImpl function adds random white noise to an action selected using the current agent policy.

function action = getActionWithExplorationImpl(obj,Observation)
    % Given the current observation, select an action
    action = getAction(obj,Observation);

    % Add random noise to the action
    num = size(obj.R,1);
    action = action + 0.1*randn(num,1);
end
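The LQR agent explores with additive Gaussian noise, but for a finite action set the epsilon-greedy scheme mentioned above is a common alternative. The following base-MATLAB sketch shows only the selection logic you might place inside getActionWithExplorationImpl; the action set, the Q-values, and epsilon are hypothetical values, and in a real agent the Q-values would come from your critic.

```matlab
% Hypothetical epsilon-greedy selection over a finite action set (not used by the LQR agent)
rng(0);                     % fix the random seed for reproducibility
actions = [-1 0 1];         % example finite action set (assumed)
qValues = [0.2 0.9 0.4];    % example Q-values for one observation (assumed)
epsilon = 0.1;              % probability of exploring (assumed)

if rand < epsilon
    idx = randi(numel(actions));   % explore: pick any action uniformly at random
else
    [~,idx] = max(qValues);        % exploit: pick the action with the highest Q-value
end
action = actions(idx);
```

In practice, epsilon is often decayed over the course of training so that the agent explores less as its critic improves.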


The learnImpl function defines how the agent learns from the current experience. This function implements the custom learning algorithm of your agent by updating the policy parameters and selecting an action with exploration. This function must have the following signature, where obj is the agent object, exp is the current agent experience, and action is the selected action.

function action = learnImpl(obj,exp)

The agent experience is the cell array exp = {state,action,reward,nextState,isDone}, where:

  • state is the current observation.

  • action is the current action.

  • reward is the current reward.

  • nextState is the next observation.

  • isDone is a logical flag indicating that the training episode is complete.
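The following base-MATLAB sketch builds a hypothetical one-step experience with the layout above and parses it with the same indexing that learnImpl uses. It assumes single-channel observations and actions, each stored in its own cell, with scalar reward and isDone values; all numeric values are made up. The sketch names the cell array experience rather than exp to avoid shadowing the built-in exp function in the workspace.

```matlab
% Hypothetical one-step experience for a single-channel observation and action
state     = {[1; 0; -1]};        % current observation, one cell per channel (assumed values)
action    = {0.5*ones(3,1)};     % current action (assumed values)
reward    = -2.3;                % scalar reward (assumed value)
nextState = {[0.9; 0.1; -0.8]};  % next observation (assumed values)
isDone    = false;               % episode not finished

% learnImpl receives this cell array as its exp argument
experience = {state,action,reward,nextState,isDone};

% Same indexing as learnImpl: outer braces select the field, inner braces the channel
x  = experience{1}{1};
u  = experience{2}{1};
dx = experience{4}{1};
```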

For the custom LQR agent, the critic parameters are updated after every N buffered experiences, where N is the value of the EstimateNum property.
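One way to read the update in the learnImpl code is as a least-squares fit of a Bellman relation. Assuming, as the code suggests, that the critic represents a quadratic cost-to-go $Q_\theta(x,u) = \theta^\top \phi(x,u)$, where $\phi$ is the basis returned by computeQuadraticBasis, $\gamma$ is the Gamma property, and the policy is $u = -Kx$, each experience $(x_k, u_k, x_{k+1})$ contributes one equation. This interpretation is inferred from the code rather than stated by the example.

$$
\theta^\top \phi(x_k,u_k) = x_k^\top Q\, x_k + u_k^\top R\, u_k + \gamma\, \theta^\top \phi(x_{k+1}, -K x_{k+1})
$$

Moving the $\theta$ terms to the left-hand side gives one row of a linear system:

$$
H_k\, \theta = y_k, \qquad H_k = \big(\phi(x_k,u_k) - \gamma\, \phi(x_{k+1}, -K x_{k+1})\big)^\top, \qquad y_k = x_k^\top Q\, x_k + u_k^\top R\, u_k
$$

In the code, $H_k$ is h1 - obj.Gamma*h2 and $y_k$ is y. Stacking $N$ rows into HBuffer and YBuffer and solving in the least-squares sense gives the normal-equation solution, which requires $H$ to have full column rank:

$$
\theta = \left(H^\top H\right)^{-1} H^\top y
$$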

function action = learnImpl(obj,exp)
    % Parse the experience input
    x = exp{1}{1};
    u = exp{2}{1};
    dx = exp{4}{1};
    y = (x'*obj.Q*x + u'*obj.R*u);
    num = size(obj.Q,1) + size(obj.R,1);

    % Wait N steps before updating the critic parameters
    N = obj.EstimateNum;
    h1 = computeQuadraticBasis(x,u,num);
    h2 = computeQuadraticBasis(dx,-obj.K*dx,num);
    H = h1 - obj.Gamma* h2;
    if obj.Counter<=N
        obj.YBuffer(obj.Counter) = y;
        obj.HBuffer(obj.Counter,:) = H;
        obj.Counter = obj.Counter + 1;
    else
        % Update the critic parameters based on the batch of
        % experiences
        H_buf = obj.HBuffer;
        y_buf = obj.YBuffer;
        theta = (H_buf'*H_buf)\H_buf'*y_buf;
        obj.Critic = setLearnableParameters(obj.Critic,{theta});

        % Derive a new gain matrix based on the new critic parameters
        obj.K = getNewK(obj);

        % Reset the experience buffers
        obj.Counter = 1;
        obj.YBuffer = zeros(N,1);
        obj.HBuffer = zeros(N,0.5*num*(num+1));
        obj.KUpdate = obj.KUpdate + 1;
        obj.KBuffer{obj.KUpdate} = obj.K;
    end

    % Find and return an action with exploration
    action = getActionWithExploration(obj,exp{4});
end
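The batch update in learnImpl solves an ordinary least-squares problem with the normal equations. The following base-MATLAB sketch checks that step on synthetic data: it builds a random H with more rows than columns, generates y from known parameters, and recovers them. The sizes and data are hypothetical; random rows stand in for the basis differences h1 - Gamma*h2 that the agent buffers.

```matlab
% Hypothetical batch least-squares fit of critic parameters theta
rng(0);                              % reproducible synthetic data
numParams = 6;                       % number of basis terms (assumed)
numSamples = 20;                     % buffered experiences; must be >= numParams
H = randn(numSamples,numParams);     % rows stand in for h1 - Gamma*h2
thetaTrue = (1:numParams)';          % synthetic "true" parameters
y = H*thetaTrue;                     % entries stand in for x'*Q*x + u'*R*u

thetaNormal = (H'*H)\H'*y;           % normal equations, as written in learnImpl
thetaQR = H\y;                       % backslash solves the same problem via QR
```

Both expressions return the same parameters here. The backslash form H\y avoids explicitly forming H'*H, whose condition number is the square of that of H, so it is generally the more numerically robust choice when the buffered rows are nearly collinear.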

Optional Functions

Optionally, you can define how your agent is reset at the start of training by specifying a resetImpl function with the following function signature, where obj is the agent object. Using this function, you can set the agent into a known or random condition before training.

function resetImpl(obj)

Also, you can define any other helper functions in your custom agent class as required. For example, the custom LQR agent defines a createCritic function for creating the critic and a getNewK function that derives the feedback gain matrix from the trained critic parameters.

Create Custom Agent

After you define your custom agent class, create an instance of it in the MATLAB workspace. For example, to create the custom LQR agent, define the Q, R, and InitialK values and call the constructor function.

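% A and B are the state-space matrices of the system to control, as defined
% in Train Custom LQR Agent (the place function requires Control System Toolbox)
Q = [10,3,1;3,5,4;1,4,9];
R = 0.5*eye(3);
K0 = place(A,B,[0.4,0.8,0.5]);
agent = LQRCustomAgent(Q,R,K0);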

After creating the agent, you can train it in a reinforcement learning environment. For an example that trains the custom LQR agent, see Train Custom LQR Agent.

