Entity-Based Reinforcement Learning

Deep reinforcement learning is a powerful technique for creating effective decision-making systems, but its complexity has hindered widespread adoption. Despite the perceived cost of RL, a wide range of interesting applications are already feasible with current techniques. The main barrier to broader use of RL is now the lack of accessible tooling and infrastructure. In this blog post, I introduce the Entity Neural Network (ENN) project, a set of libraries designed to simplify the process of applying RL to complex simulated environments, significantly reducing the required engineering effort, computational cost, and expertise. The enn-trainer RL framework can be seamlessly integrated with complex simulators without requiring users to write custom training code or network architectures. Furthermore, the framework trains RL policies at a rate of more than 100,000 samples/s on a single consumer GPU, achieves performance similar to the IMPALA network architecture while using 50x fewer parameters and 30x-600x fewer FLOPS, and produces RL agents efficient enough to run in real time on the CPU inside web browsers.

Introduction

Reinforcement learning (RL) is a subfield of machine learning in which an agent learns to make decisions by interacting with its environment and receiving feedback in the form of rewards or penalties. By optimizing actions to maximize cumulative rewards, RL offers a versatile framework for solving various discrete optimization problems without extensive domain knowledge. Applications range from video game AIs​1–4​ and circuit design​5​ to data center cooling​6​, personalized suggestions​7​, aligning language models with human preferences​8​, and even discovering novel algorithms​9​. Although RL’s data requirements can be a limiting factor, its viability increases when simulators can generate affordable, large-scale data, especially in domains like video games. However, the expertise and effort needed to implement RL remain obstacles to broader experimentation and unlocking its full potential.

To exemplify the difficulty of RL, consider how one might train an AI for a simple real-time strategy game​10​ where two players each control an army of units that gather resources, create new units, and fight against the units of the other player.

Figure 1: Scene from a simple strategy game that shows three units from the orange player, one unit from the blue player, and two resource nodes.

This task can be straightforwardly formulated as an RL problem where the RL agent observes a collection of game objects, returns an action for every game unit under its control, and is rewarded for building new units and winning the game. The observations and actions for the scene in Figure 1 might be expressed as follows:

observation = {
    "allies": [
        {
            "x": 213.0, "y": 1223.0, "health": 1, "attack": 1,
            "production": 0, "storage": 0, "resources": 0,
        },
        {
            "x": 897.0, "y": 304.0, "health": 12, "attack": 3,
            "production": 3, "storage": 3, "resources": 4
        },
        {
            "x": 194.0, "y": 72.0, "health": 1, "attack": 0,
            "production": 0, "storage": 1, "resources": 5
        },
    ],
    "mineral-crystals": [
        { "x": 977.0, "y": 289.0, "size": 4012 },
        { "x": 201.0, "y": 236.0, "size": 3 },
    ],
    "enemies": [
        {
            "x": 158.0, "y": 1321.0, "health": 1, "attack": 1,
            "production": 0, "storage": 0, "resources": 0
        },
    ],
}

actions = ["fire", "move_forward", "collect"]

In theory, a thin integration layer connecting this game representation to an RL framework should be sufficient. However, attempting this reveals several obstacles. One significant challenge is that existing APIs for RL environments mandate fixed-size vectors with consistent shapes for observations at every time step. This rigid interface is ill-suited for games or complex simulations featuring varying numbers of distinct objects. Consequently, we face two difficult choices: extensively modifying existing RL frameworks (or even creating our own), or using image-based representations or other methods to compress observations into fixed-size vectors. Developing an RL framework demands considerable engineering effort, ML expertise, and tears*, while image-based observations often prove too computationally inefficient for most interesting applications.

* I once made the mistake of implementing the proximal policy optimization algorithm from scratch. This was so deeply traumatizing that it compelled me to write a 3000-word poem about the experience.

Entity Gym and friends

The limited expressiveness in the observation and action spaces of existing RL interfaces is the primary motivation for the entity-neural-network project. This project has developed a set of libraries that bring RL to entity-based environments, allowing for more flexible and efficient interactions:

  • Entity Gym (filling a role similar to that of OpenAI Gym) defines an entity-based abstraction for RL environments that allows observations to be composed of a variable number of different objects.
  • enn-trainer provides an RL implementation that can train agents in Entity Gym environments.
  • RogueNet implements a transformer-based neural network architecture capable of efficiently processing variable-sized collections of objects and controlling agents in enn-trainer.

The Entity Gym library enables seamless integration of complex simulators and RL frameworks by providing first-class support for variable-size observations. Environments implementing the Entity Gym interface can directly return observations as a dictionary containing lists of different entities, rather than a single vector (full example here):

Observation(
    entities={
        "Drone": [(d.x, d.y, d.health, *d.modules) for d in self.drones],
        "MineralCrystals": [(m.x, m.y, m.size) for m in self.minerals],
    },
    ...
)

The enn-trainer library is an Entity Gym compatible training framework that implements the proximal policy optimization algorithm, a popular RL algorithm for optimizing an agent's policy. It is powered by RogueNet, a ragged-batch transformer​11​ implementation that recognizes the set structure of entity-based observations and can efficiently process batches of variable-sized observations. On the Procgen benchmark​12​, RogueNet performs similarly to the vision-based IMPALA network architecture​13​ while using roughly 50x fewer parameters and 30x-600x fewer FLOPS, and it learns more consistently (details in appendices A and B).

Figure 2: Learning curves comparing the performance of the IMPALA (orange) and RogueNet (blue) neural network architectures on the 16 Procgen environments.

Blazingly fast simulators and inference

Fully utilizing the throughput that enn-trainer can sustain requires environments more efficient than what can easily be achieved in Python. The entity-gym-rs crate provides bindings that allow simulations implemented in Rust to be exported as entity-gym environments, enabling training speeds of more than 100,000 frames per second on a single RTX 2080 Ti and a consumer CPU. By utilizing Rust derive macros, entity-gym-rs offers a highly ergonomic API that supports normal Rust data structures as inputs and outputs of the neural network agents:

#[derive(Action, Debug)]
enum Move { Up, Down, Left, Right }

#[derive(Featurizable)]
struct Player { x: i32, y: i32, health: u32 }

#[derive(Featurizable)]
struct Enemy { x: i32, y: i32 }

let mut agent = Agent::random();
let obs = Obs::new(0.0)
    .entities([Player { x: 0, y: 0, health: 10 }])
    .entities([
        Enemy { x: 4, y: 0 },
        Enemy { x: 10, y: 42 },
    ]);
let action = agent.act::<Move>(obs);

While training agents for entity-gym-rs environments leverages the Python machine learning ecosystem, the trained agents can be loaded by Rust programs and run without any external dependencies. This is made possible by the rogue-net crate, a pure Rust implementation of the RogueNet network architecture. RogueNet agents are fast enough to run on the CPU and can be compiled to WebAssembly and executed inside browsers. The Bevy Starfighter demo puts all these pieces together to train neural network opponents for a simple 2D space shooter game. Click here to try it out in your browser.

Figure 3: Player and opponents for simple space shooter game controlled in real-time by RogueNet Rust.

Technical deep dive

The high level of flexibility offered by the Entity Gym API poses a number of difficulties for the implementation of the RL training machinery. This section describes the unique parts of enn-trainer and RogueNet that allow them to efficiently support Entity Gym environments.

Background

The object of reinforcement learning is to train a policy that interacts with an environment in a way that maximizes a given reward signal. A policy is a function $\pi(Observation) \to \text{Map}[Action, Float]$ that takes an observation of the environment state as an input, and returns a distribution that assigns a probability to each possible action.

Similarly, an environment is a stateful function $E(Action) \to (Observation, Reward)$ that takes an action as an input, advances the internal environment state, and returns an observation of the new state together with a reward value.

Most reinforcement learning algorithms proceed in two phases: rollout and optimization. During the rollout phase, we alternate between applying an action to the environment to obtain a new observation and feeding the observation into the policy to obtain a distribution over actions and sample the next action.

For efficiency reasons, we usually run many environment instances in parallel.

The actions, observations, and rewards produced during the rollout are collected in a sample buffer.

During the optimization phase, we run an algorithm that uses the samples collected during the rollout to adjust the policy in a way that increases the probability of taking actions that lead to a high reward.
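
To make the rollout/optimization split concrete, here is a minimal, self-contained Python sketch with a toy environment and a random policy standing in for an Entity Gym environment and a RogueNet policy (illustration only, not the actual enn-trainer code):

import random

# Toy stand-ins for an environment and a policy; the real system uses
# Entity Gym environments and a RogueNet policy.
class ToyEnv:
    def __init__(self):
        self.state = 0
    def reset(self):
        self.state = 0
        return self.state
    def step(self, action):
        self.state += action
        reward = 1.0 if self.state > 0 else 0.0
        return self.state, reward

def policy(obs):
    # Return a probability for each of two actions (-1, +1).
    return {-1: 0.5, +1: 0.5}

# Rollout phase: alternate between acting and observing, and record samples.
env = ToyEnv()
buffer = []
obs = env.reset()
for step in range(256):
    probs = policy(obs)
    action = random.choices(list(probs), weights=list(probs.values()))[0]
    next_obs, reward = env.step(action)
    buffer.append((obs, action, reward))
    obs = next_obs

# Optimization phase (not shown): use `buffer` to adjust the policy so that
# high-reward actions become more likely, e.g. with PPO.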

Ragged sample buffer

In traditional RL training frameworks, where observations can be assumed to be fixed-size vectors, sample buffers can be implemented as a statically allocated tensor with a shape of [#timesteps, #environments, #features]. However, Entity Gym observations can contain a varying number of entities and would require a tensor of shape [#timesteps, #environments, #entities, #features], where #entities can be different at every time step. This prevents us from using libraries like NumPy or PyTorch, whose tensors must be rectangular. To efficiently buffer samples, enn-trainer uses a Rust implementation of 3D ragged arrays that support variable sequence lengths in the second dimension. This allows the sample buffer to contain an arbitrary number of entities at every time step.

Internally, the ragged buffer datatype stores the data in a contiguous backing array, with some additional metadata to keep track of the number of items in each sequence. This is sufficient to efficiently support all the indexing/slicing/shuffling/concatenation/… operations that we require from our sample buffer.
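
As a simplified illustration of the same idea (a 2D rather than 3D ragged array, and only a fraction of the operations the Rust ragged-buffer implementation provides), a ragged buffer might be sketched in Python like this:

import numpy as np

class RaggedBuffer:
    """Toy 2D ragged array: variable-length sequences over a fixed feature dimension."""
    def __init__(self, features):
        self.features = features
        self.data = np.zeros((0, features), dtype=np.float32)  # contiguous backing array
        self.lengths = []                                      # entities per sequence

    def push(self, sequence):
        # `sequence` is an array-like of shape [num_entities, features].
        sequence = np.asarray(sequence, dtype=np.float32).reshape(-1, self.features)
        self.data = np.concatenate([self.data, sequence])
        self.lengths.append(len(sequence))

    def __getitem__(self, i):
        # Recover sequence i from the flat backing array using the length metadata.
        start = sum(self.lengths[:i])
        return self.data[start:start + self.lengths[i]]

buf = RaggedBuffer(features=3)
buf.push([[977.0, 289.0, 4012.0], [201.0, 236.0, 3.0]])  # two entities
buf.push([[158.0, 1321.0, 1.0]])                          # one entity
print(buf[0].shape, buf[1].shape)  # (2, 3) (1, 3)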

RogueNet

Most importantly, we require a neural network architecture that can deal with structured observations that contain multiple entities with different numbers of features. The basic idea is to project all entities into a unified embedding space where entities of all types are represented by vectors with the same number of dimensions and can be combined into a single sequence. This representation can be sent through multiple layers of a (mostly) standard transformer, with additional provisions to handle variable-length sequences. Finally, we can select a subset of entities to send to action heads that output probability distributions over the specified action spaces.

Embedding

The input to the network is a dictionary which maps each entity type e to a ragged array of shape [T, *N, D_e], where T ranges over all environments and time steps, *N is the number of entities on a particular time step, and D_e is the number of features of entity type e. For each entity type, RogueNet has an embedding layer that flattens the ragged sequences into a 2D tensor combining the first two dimensions, normalizes all features by their running mean and variance, and projects them into a space of dimension d_model by feeding them through a linear layer, ReLU, and LayerNorm. This gives us tensors of shape [B_e, d_model] for every entity type, which can be concatenated into a single [B, d_model] sized tensor, where B is the total number of entities across all environments and time steps and d_model is the hidden dimension of the network. To allow us to later unpack this flattened sequence again, the embedding layer also constructs several metadata tensors that keep track of where each element in the flattened sequence came from.
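
A rough PyTorch sketch of this embedding step (simplified: it omits the running-mean/variance normalization and the metadata tensors, and the feature counts are made-up example values):

import torch
import torch.nn as nn

d_model = 16
feature_counts = {"Drone": 7, "MineralCrystals": 3}  # features per entity type (example values)

# One embedding per entity type: Linear -> ReLU -> LayerNorm into a shared d_model space.
embeddings = nn.ModuleDict({
    name: nn.Sequential(nn.Linear(n_features, d_model), nn.ReLU(), nn.LayerNorm(d_model))
    for name, n_features in feature_counts.items()
})

def embed(observation):
    # `observation` maps entity type -> float tensor of shape [num_entities, num_features]
    # (already flattened across environments and time steps).
    embedded = [embeddings[name](feats) for name, feats in observation.items()]
    return torch.cat(embedded, dim=0)  # shape [total_entities, d_model]

obs = {
    "Drone": torch.randn(3, 7),
    "MineralCrystals": torch.randn(2, 3),
}
x = embed(obs)
print(x.shape)  # torch.Size([5, 16])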

Ragged attention

The multi-head attention operation allows the transformer to move information between individual elements within each sequence. In conventional transformer implementations, the activations are of shape [B, S, D], where B is the batch size, S the (fixed) sequence length, and D the model dimension. However, in our case, the sequences do not all have the same length, and we have concatenated/flattened all observations into a single sequence with tensor shape [B, D]. For two reasons we cannot directly attend over this entire sequence. First, the time complexity of attention is quadratic in the sequence length, which would make attention over the entire flattened sequence prohibitively expensive. Second, we have to disable attention across the different sub-sequences to prevent the policy from accessing information from future time steps or the observations of its opponent. We solve the first problem by unpacking the flattened sequence into a 3D tensor that pads all sequences to the length of the longest subsequence and packs smaller sequences together to reduce the amount of padding. We solve the second problem by constructing a mask that zeroes out attention between entities that come from different time steps or environments.
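
As an illustration, the mask can be derived from per-entity sequence IDs; the hypothetical sketch below blocks attention both across sub-sequences packed into the same row and to padding slots (it is not the actual RogueNet code):

import torch

def packed_attention_mask(sequence_ids):
    # sequence_ids has shape [rows, row_len]; each element identifies which
    # (environment, time step) sequence the entity came from, with -1 for padding.
    # True marks pairs of entities that are allowed to attend to each other.
    same_seq = sequence_ids[:, :, None] == sequence_ids[:, None, :]
    not_padding = (sequence_ids != -1)[:, :, None] & (sequence_ids != -1)[:, None, :]
    return same_seq & not_padding

# Two sequences of lengths 2 and 1 packed into one row of length 4 (one padding slot).
ids = torch.tensor([[0, 0, 1, -1]])
print(packed_attention_mask(ids)[0].int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 0, 0]])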

Feedforward

Since the feedforward part of the network is applied independently to each entity, no changes are necessary compared to a normal transformer implementation.

Multi-entity action heads

For the most part, categorical actions work exactly the same in Entity Gym as in conventional RL frameworks. One additional feature offered by Entity Gym is that it allows actions to be performed by any subset of entities rather than always outputting a single action on each time step. To support this, we select the subset of entities that are performing an action at a particular time step before sending them through each action head.
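
In simplified PyTorch terms, this amounts to indexing out the acting entities before applying a categorical head (a hypothetical sketch; in practice the actor indices come from the action masks supplied by the environment):

import torch
import torch.nn as nn

d_model, num_choices = 16, 4               # e.g. 4 movement directions
move_head = nn.Linear(d_model, num_choices)

entities = torch.randn(6, d_model)         # embeddings of all entities at this time step
actor_indices = torch.tensor([0, 2, 5])    # subset of entities that act this step

logits = move_head(entities[actor_indices])           # [num_actors, num_choices]
dist = torch.distributions.Categorical(logits=logits)
actions = dist.sample()                                # one action per acting entity
print(actions)                                         # e.g. tensor([1, 3, 0])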

Select-entity action heads

Entity Gym also supports “SelectEntity” actions, which allow the agent to select a particular entity. Examples of when this kind of action is useful include selecting which card to play from the current hand, abilities that target another unit, or choosing an item to equip. This kind of action is special because the set of choices is not fixed and instead depends on the number and properties of the selectable entities. To implement select-entity actions, we project a “query” from the embedding of each actor and a “key” from the embedding of each actee (selectable entity). The dot product between a query and a key gives a score that determines how much an actor wants to select a particular actee. These scores are then turned into a normalized probability distribution by applying a softmax over the scores of each actor.
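
A minimal sketch of such a select-entity head (simplified: it ignores batching, ragged shapes, and masking of non-selectable entities):

import torch
import torch.nn as nn

d_model = 16
query_proj = nn.Linear(d_model, d_model)   # projects actor embeddings to queries
key_proj = nn.Linear(d_model, d_model)     # projects selectable-entity embeddings to keys

actors = torch.randn(2, d_model)           # e.g. two units choosing a target
actees = torch.randn(5, d_model)           # e.g. five selectable entities

scores = query_proj(actors) @ key_proj(actees).T   # [num_actors, num_actees]
probs = torch.softmax(scores, dim=-1)              # per-actor distribution over actees
targets = torch.distributions.Categorical(probs=probs).sample()
print(targets)                                     # e.g. tensor([3, 0])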

Conclusion

In this blog post, we introduced a collection of libraries, including Entity Gym, RogueNet, and enn-trainer, designed to bring reinforcement learning to entity-based environments. We discussed the challenges faced when integrating complex simulators with existing RL frameworks and presented our solutions, including the RogueNet neural network architecture and the use of Rust for high-performance simulations.

While we have made some progress in addressing the limitations of existing RL frameworks and providing support for entity-based environments, we believe there is still much to be explored in this area. The development of Entity Gym opens the door for others to build upon our work and investigate new applications in various domains, such as video games, complex simulations, and real-time decision-making.

Our hope is that the research community finds value in Entity Gym and that it serves as a starting point for further exploration and innovation in the realm of reinforcement learning with entity-based environments. By working together, we can continue to push the boundaries of what is possible in RL and help shape the future of intelligent systems.

Acknowledgements

Thanks to Benedikt Winter, Costa Huang, and Joseph Suárez for reviewing drafts of this blog post. The entity neural networks project was built in collaboration with Chris Bamford, Costa Huang, Théo Matricon, and Anssi Kanervisto.

Appendix A – Procgen Experiments

The Procgen Benchmark is a set of 16 procedurally-generated environments modeled after classic Atari games. The IMPALA experiments were run with PPO using CleanRL at 42d21bd (W&B project). The RogueNet experiments were run with enn-trainer/enn-zoo at ebdff69 (W&B report). The enn-trainer PPO implementation is derived from CleanRL and should be very comparable. All experiments use the same hyperparameter values, except for learning rate and entropy loss, which were tuned separately for RogueNet.

The IMPALA experiments use the standard OpenAI Gym interface with image-based observations of shape (64, 64, 3). To expose an Entity Gym interface for Procgen, I use the get_state method, which returns a serialized representation of the full state of a Procgen game intended for saving/restoring environment state. Internally, Procgen represents all entities with the same Entity class. The serialized game state contains a list of these, which can be reconstructed and turned into an Entity Gym observation. Procgen entities have a type property, and entities with different types are treated as separate entity types in the Entity Gym observation space. Procgen also has additional environment-specific state, which is exposed to Entity Gym as global features. Many Procgen environments also contain a tilemap, a dense 2D grid that determines the type of terrain/object at every location in the game. Entity Gym doesn’t currently have native support for this kind of structure, so the integration converts the tilemap into multiple entities by subdividing it into 5×5 chunks, one-hot encoding the tile type at each of the 25 positions, and adding two additional features for the x and y location of the chunk. In some Procgen environments, a portion of entities can be off-screen and cannot be observed in image-based observations. While this should not affect the optimal policy in most cases, it could give the RogueNet policies a small edge. Further details can be found by consulting the code.
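
The chunking scheme can be illustrated with a short NumPy sketch (hypothetical tile values and grid size; the actual conversion lives in the enn-zoo integration code):

import numpy as np

def tilemap_to_entities(tilemap, num_tile_types, chunk=5):
    # tilemap: integer array of shape [height, width] with one tile type per cell.
    # Each 5x5 chunk becomes one entity: 25 one-hot encoded tiles plus the (x, y) of the chunk.
    h, w = tilemap.shape
    entities = []
    for y in range(0, h, chunk):
        for x in range(0, w, chunk):
            tiles = tilemap[y:y + chunk, x:x + chunk].flatten()
            one_hot = np.eye(num_tile_types)[tiles].flatten()  # 25 * num_tile_types features
            entities.append(np.concatenate([one_hot, [x, y]]))
    return np.stack(entities)

tilemap = np.random.randint(0, 3, size=(20, 20))   # 20x20 grid with 3 tile types
entities = tilemap_to_entities(tilemap, num_tile_types=3)
print(entities.shape)  # (16, 77): 4x4 chunks, 25 tiles x 3 types + 2 position features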

Full set of hyperparameters:

hyperparameter                 RogueNet      IMPALA
optim.lr/learning_rate         0.01          0.0005
optim.anneal_lr/anneal_lr      true          false
optim.bs/batch_size            2048          2048
total_timesteps                25,000,000    25,000,000
rollout.num_envs/num_envs      64            64
rollout.steps/steps            256           256
ppo.ent_coef/ent_coef          0.05          0.01
ppo.vf_coef/vf_coef            0.5           0.5
ppo.clip_coef/clip_coef        0.2           0.2
ppo.gamma/gamma                0.999         0.999
ppo.gae_lambda/gae_lambda      0.95          0.95
ppo.norm_adv/norm_adv          true          true
ppo.clip_vloss/clip_vloss      true          true
ppo.anneal_entropy             true          false
net.d_model                    16            –
net.n_layer                    2             –
net.n_head                     2             –
channels                       –             [16, 32, 32]

Notebook used to create Figure 2.

Appendix B – Parameters and FLOPs

This section derives equations for the approximate number of parameters and floating point multiply-add operations performed by the RogueNet and IMPALA architectures. For simplicity, the contributions of activation functions, biases, and other smaller network components are ignored.

RogueNet

Let $D$ be the hidden dimension of the model (d_model), $L$ the number of layers, $S_i$ the number of entities of type $i$, $S = \sum_i S_i$ the sequence length, and $F_i$ the number of features of entity type $i$.

The embedding has a total of $D \sum_i F_i$ parameters for the matrices projecting the features of each entity into a space of dimension $D$, which takes $ D \sum_i F_i S_i$ FLOPS.

The MLP consists of two matrices of shapes $D \times 4D$ and $4D \times D$ for a total of $8D^2$ parameters and $8D^2 S$ FLOPS.

The attention layer has a total of $4D^2$ parameters and $4D^2S$ FLOPS in the key/value/query/output projections. The self-attention operation itself multiplies a $S \times D$ and a $D \times S$ matrix for an additional $D S^2$ FLOPS.

The total number of parameters and compute of a RogueNet is given by:

$$\text{params} =12LD^2 + D \sum\limits_i F_i$$

$$\text{compute} =12LD^2S + LDS^2 + D \sum\limits_i F_i S_i$$

The values of $F_i$ and $S_i$ are environment dependent. In a typical Procgen environment, $S=10$ and there are 5 different entity types with an average of 50 features each. For $D=16$ and $L=2$, this yields around 10K parameters and 100K FLOPS (60x fewer parameters and 430x less compute than IMPALA). At the extremes, the bossfight environment averages 80 entities, yielding 14K parameters and 1.1M FLOPS (43x fewer parameters and 29x less compute than IMPALA), while the miner environment has 6 entities, yielding 10K parameters and 49K FLOPS (62x fewer parameters and 629x less compute than IMPALA).
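
These formulas translate directly into a small helper function; the entity statistics passed in below are made-up placeholders, since the real values vary per environment:

def roguenet_params_and_flops(d, layers, entity_counts, feature_counts):
    # entity_counts[i]: number of entities of type i in an observation (S_i)
    # feature_counts[i]: number of features of entity type i (F_i)
    s = sum(entity_counts)
    params = 12 * layers * d**2 + d * sum(feature_counts)
    flops = (12 * layers * d**2 * s
             + layers * d * s**2
             + d * sum(f * n for f, n in zip(feature_counts, entity_counts)))
    return params, flops

# Rough example call with placeholder entity statistics:
print(roguenet_params_and_flops(d=16, layers=2, entity_counts=[2] * 5, feature_counts=[50] * 5))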

Detailed calculations can be found in flops.ipynb.

For statistics on entity counts in Procgen environments, see this W&B report.

IMPALA

The input and activation in IMPALA are images with a height $H$, width $W$, and $C$ channels.

At the lowest level, IMPALA is made up of ResidualBlocks with two 3×3 convolutions for a total of $18C^2$ parameters and $18C^2HW$ FLOPS.

One level up, the ConvSequence applies a 3×3 convolution that projects the number of channels from $C$ to $C'$, a max pooling layer that halves the height and width of the activations, followed by two ResidualBlocks. Summing it all up, we get a total of $9CC' + 36C'^2$ parameters and $9CC'HW + 36C'^2H'W'$ FLOPS, where $H' = (H + 1) // 2$ and $W' = (W + 1) // 2$.

In the final layer, we flatten the image and project it into a vector of dimension 256, which takes $256CHW$ parameters and FLOPS, where $C$, $H$, and $W$ are the dimensions of the output of the last ConvSequence (this layer accounts for 83% of parameters but < 2% of FLOPS).

The version of IMPALA used in Procgen takes images of shape 3x64x64 as input and applies 3 ConvSequences that output 16, 32, and 32 channels, followed by the final projection. The total number of parameters and FLOPS is given by:

$$\text{params} = 9,648 + 41,472 + 46,080 + 524,288 \approx 621\text{K} $$

$$\text{flops} = 11,206,656 + 14,155,776 + 4,718,592 + 524,288 \approx 31\text{M} $$
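
For reference, a short Python script that reproduces these totals from the ConvSequence formulas above (like the derivation, it ignores biases and activation functions):

def conv_sequence(c_in, c_out, h, w):
    # 3x3 projection conv + two ResidualBlocks (each two 3x3 convs) after 2x max pooling.
    h2, w2 = (h + 1) // 2, (w + 1) // 2
    params = 9 * c_in * c_out + 36 * c_out**2
    flops = 9 * c_in * c_out * h * w + 36 * c_out**2 * h2 * w2
    return params, flops, h2, w2

c, h, w = 3, 64, 64
total_params, total_flops = 0, 0
for c_out in [16, 32, 32]:
    p, f, h, w = conv_sequence(c, c_out, h, w)
    total_params, total_flops, c = total_params + p, total_flops + f, c_out

# Final flattening + projection to a 256-dimensional vector.
total_params += 256 * c * h * w
total_flops += 256 * c * h * w

print(total_params, total_flops)  # 621488 30605312, i.e. ~621K parameters and ~31M FLOPS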

References

1. Berner C, Brockman G, Chan B, et al. Dota 2 with large scale deep reinforcement learning. arXiv preprint arXiv:1912.06680. Published online 2019.
2. Silver D, Hubert T, Schrittwieser J, et al. A general reinforcement learning algorithm that masters chess, shogi, and Go through self-play. Science. 2018;362:1140–1144.
3. Wurman PR, Barrett S, Kawamoto K, et al. Outracing champion Gran Turismo drivers with deep reinforcement learning. Nature. 2022;602:223–228.
4. Vinyals O, Babuschkin I, Czarnecki WM, et al. Grandmaster level in StarCraft II using multi-agent reinforcement learning. Nature. 2019;575:350–354.
5. Roy R, Raiman J, Kant N, et al. PrefixRL: Optimization of Parallel Prefix Circuits using Deep Reinforcement Learning. In: 2021 58th ACM/IEEE Design Automation Conference (DAC); 2021:853–858. doi:10.1109/DAC18074.2021.9586094
6. Gamble C, Gao J. Safety-first AI for Autonomous Data Centre Cooling and Industrial Control. Published online 2018. https://www.deepmind.com/blog/safety-first-ai-for-autonomous-data-centre-cooling-and-industrial-control
7. Gauci J, Conti E, Liang Y, et al. Horizon: Facebook’s Open Source Applied Reinforcement Learning Platform. Published online 2018. doi:10.48550/ARXIV.1811.00260
8. Ouyang L, Wu J, Jiang X, et al. Training language models to follow instructions with human feedback. Published online 2022. doi:10.48550/ARXIV.2203.02155
9. Fawzi A, Balog M, Huang A, et al. Discovering faster matrix multiplication algorithms with reinforcement learning. Nature. 2022;610:47–53. doi:10.1038/s41586-022-05172-4
10. Winter C. Mastering Real-Time Strategy Games with Deep Reinforcement Learning: Mere Mortal Edition. Clemens’ Blog. Published online 2021. https://clemenswinter.com/2021/03/24/mastering-real-time-strategy-games-with-deep-reinforcement-learning-mere-mortal-edition/
11. Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need. arXiv preprint arXiv:1706.03762. Published online 2017.
12. Cobbe K, Hesse C, Hilton J, Schulman J. Leveraging procedural generation to benchmark reinforcement learning. In: International Conference on Machine Learning. PMLR; 2020:2048–2056.
13. Espeholt L, Soyer H, Munos R, et al. IMPALA: Scalable distributed deep-RL with importance weighted actor-learner architectures. In: International Conference on Machine Learning. PMLR; 2018:1407–1416.
