How to create an environment for a Tensorflow Agent?

In this article, I am going to show you how to create a Tensorflow Agent environment from scratch. I will not train an agent. Instead, I am going to make random moves in that environment and see what happens.

The goal

I want to teach the agent how to play TicTacToe, but I want to take baby steps. The first version of the game environment allows only one player to play, and it does not check the winning conditions. It is the simplest and most basic environment. It teaches the agent to pick an empty spot. People understand this rule intuitively, but a reinforcement learning agent does not know it and has to learn it.

My goal is to get an agent that makes only valid moves: it chooses an empty spot on the board and never tries to pick an occupied one. The game ends successfully when the agent manages to fill the whole board. The agent loses when it makes an illegal move (picks the same spot twice).

Define the environment

Note that I am using the nightly build of Tensorflow Agents that was available at the time of writing this article. Let’s start with the mandatory imports.

import tensorflow as tf
import numpy as np

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

tf.compat.v1.enable_v2_behavior()


Action spec & observation spec

The environment consists of four elements. First, I have to specify the available actions, and the structure of the data returned as the representation of the environment’s state.

In my case, I have a 3x3 TicTacToe board, and I want the agent to take one action per step, so my action specification is a single number (a scalar, denoted by the empty shape tuple ()) between 0 and 8: the index of the chosen spot.

I want to inform the agent whether each field is taken or not. For that, it is sufficient to pass a single 9-element array of zeros and ones. Because of that, my observation spec has the shape (1, 9), which means “a single 9-element array.”

I also need to store the board state in an array, and I need a boolean flag indicating whether the game has ended.

class TicTacToeBoardWithNoRulesEnvironment(py_environment.PyEnvironment):

  def __init__(self):
    self._action_spec = array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=8, name='action')
    self._observation_spec = array_spec.BoundedArraySpec(
        shape=(1,9), dtype=np.int32, minimum=0, maximum=1, name='observation')
    self._state = [0, 0, 0, 0, 0, 0, 0, 0, 0]
    self._episode_ended = False

  def action_spec(self):
    return self._action_spec

  def observation_spec(self):
    return self._observation_spec
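To make the (1, 9) observation spec concrete, here is a NumPy-only sketch of what a single observation looks like and how its shape, dtype and bounds line up with the BoundedArraySpec above. The helper matches_observation_spec is hypothetical and only mirrors the spec's parameters; TF-Agents performs its own checks.

```python
import numpy as np

# Board after the agent picked spots 0 and 4 (1 = occupied, 0 = empty).
state = [1, 0, 0, 0, 1, 0, 0, 0, 0]

# Wrapping the list in another list produces the (1, 9) shape used by the spec.
observation = np.array([state], dtype=np.int32)


def matches_observation_spec(obs, shape=(1, 9), dtype=np.int32, minimum=0, maximum=1):
    """Hypothetical helper: compares an array with the observation spec's
    shape, dtype and bounds."""
    return (
        obs.shape == shape
        and obs.dtype == dtype
        and bool(np.all(obs >= minimum))
        and bool(np.all(obs <= maximum))
    )


print(observation.shape)                      # (1, 9)
print(matches_observation_spec(observation))  # True
# A flat 9-element array has shape (9,), which does not match the spec.
print(matches_observation_spec(np.array(state, dtype=np.int32)))  # False
```

The extra pair of brackets is what turns the board list into the (1, 9) shape; forgetting it is one of the mismatches the validation step later in the article is meant to catch.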

Reset

The environment is mutable, so I need a function that restores the initial state. According to the PyEnvironment documentation, such a function must be called _reset, and it must return the state of the environment and information that it is the first step of the simulation. A valid response is created by calling ts.restart with the current state as an argument.

def _reset(self):
    self._state = [0, 0, 0, 0, 0, 0, 0, 0, 0]
    self._episode_ended = False
    return ts.restart(np.array([self._state], dtype=np.int32))
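ts.restart is one of three helpers from tf_agents.trajectories.time_step that this environment uses; ts.transition and ts.termination appear in the step function below. The following is a simplified, pure-Python stand-in (not the TF-Agents source) that illustrates the fields a TimeStep carries and the values each helper sets: restart marks the first step with zero reward, transition marks an intermediate step, and termination marks the last step and sets the discount to 0.

```python
from typing import NamedTuple

import numpy as np

# Simplified stand-in for tf_agents.trajectories.time_step; NOT the library code.
FIRST, MID, LAST = 0, 1, 2  # same numeric values as TF-Agents' StepType


class TimeStep(NamedTuple):
    step_type: int
    reward: float
    discount: float
    observation: np.ndarray


def restart(observation):
    # First step of an episode: no reward yet, future rewards still count.
    return TimeStep(FIRST, 0.0, 1.0, observation)


def transition(observation, reward, discount=1.0):
    # Any step in the middle of an episode.
    return TimeStep(MID, reward, discount, observation)


def termination(observation, reward):
    # Last step: a discount of 0 means nothing after this step counts.
    return TimeStep(LAST, reward, 0.0, observation)


board = np.zeros((1, 9), dtype=np.int32)
print(restart(board).step_type, restart(board).discount)          # 0 1.0
print(termination(board, -1).step_type, termination(board, -1).discount)  # 2 0.0
```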

Board state

I also need functions which tell me whether the action taken by the agent is legal and whether the game has ended. For that, I have created two helper functions:

def __is_spot_empty(self, index):
    return self._state[index] == 0

def __all_spots_occupied(self):
    return all(item == 1 for item in self._state)

Step

Now, it is time to define the function that takes an action and returns the reward. We must override the _step function defined in the PyEnvironment class. First of all, if the game has ended, the step function must call the reset function.

def _step(self, action):
    if self._episode_ended:
        return self.reset()

Next, if the agent picked a valid action, I want to mark the spot as occupied and check whether the board is full. If the board is full, we must return the terminal state and the reward (for completing the board, the agent gets reward = 1). If the board is not complete, we inform the agent that the environment has transitioned to a new state. I also decided to give a small reward of 0.1 for every valid move. If the agent picks an occupied spot, the episode ends with reward = -1.

The discount parameter is the discount factor, which I described in my article about the Bellman equation.

#the continuation of the _step(self, action) function
    if self.__is_spot_empty(action):
        self._state[action] = 1

        if self.__all_spots_occupied():
            self._episode_ended = True
            return ts.termination(np.array([self._state], dtype=np.int32), 1)
        else:
            return ts.transition(np.array([self._state], dtype=np.int32), reward=0.1, discount=1.0)

    else:
        self._episode_ended = True
        return ts.termination(np.array([self._state], dtype=np.int32), -1)
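To see the reward scheme in action without TF-Agents, the following standalone sketch replays a fixed sequence of moves through the same rules as _step (0.1 for a valid move, 1 for filling the board, -1 for picking an occupied spot). The function name replay is hypothetical.

```python
def replay(moves):
    """Replay board positions under the reward rules of _step.

    Returns the list of rewards and whether the episode ended.
    """
    board = [0] * 9
    rewards = []
    for move in moves:
        if board[move] == 0:
            board[move] = 1
            if all(cell == 1 for cell in board):
                rewards.append(1.0)   # board completed: termination with reward 1
                return rewards, True
            rewards.append(0.1)       # valid move: transition with reward 0.1
        else:
            rewards.append(-1.0)      # spot already taken: termination with reward -1
            return rewards, True
    return rewards, False


print(replay([0, 1, 1]))                          # ([0.1, 0.1, -1.0], True)
print(round(sum(replay(list(range(9)))[0]), 4))   # 1.8
```

The second call shows the best possible episode: eight valid moves at 0.1 each, followed by the reward of 1 for filling the board.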

Validate the environment

When the environment is ready, it is best to call the validate_py_environment function. It helps to spot silly mistakes, such as uninitialized variables or incompatible data types.

environment = TicTacToeBoardWithNoRulesEnvironment()
utils.validate_py_environment(environment, episodes=5)

TensorflowEnvironment

At runtime, we want to parallelize computation, so we should convert the PyEnvironment into a TFEnvironment, which uses tensors instead of NumPy arrays. Fortunately, the TFPyEnvironment wrapper does this automatically.

tf_env = tf_py_environment.TFPyEnvironment(environment)
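As I understand TFPyEnvironment, it treats a single Python environment as a batch of size one, so every tensor it returns or accepts gains an outer batch dimension. The NumPy sketch below only illustrates the resulting shapes; the arrays are made up, not produced by TF-Agents.

```python
import numpy as np

# Observation produced by the Python environment: shape (1, 9).
py_observation = np.array([[1, 0, 0, 0, 1, 0, 0, 0, 0]], dtype=np.int32)

# A batch of one environment adds an outer dimension: shape (1, 1, 9).
batched_observation = py_observation[np.newaxis, ...]

# Actions are batched the same way: one action per environment in the batch,
# which is why the random loop in the next section draws a tensor of shape [1].
batched_action = np.array([4], dtype=np.int32)

print(batched_observation.shape)  # (1, 1, 9)
print(batched_action.shape)       # (1,)
# Indexing the batch dimension recovers the original observation.
print(np.array_equal(batched_observation[0], py_observation))  # True
```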

Using the environment

As I mentioned at the beginning of the article, I am not going to train an agent. Instead, I will make a series of random actions in the environment and see what total reward I can get.

First, I have to reset the environment and define variables that I will use to store the results of every episode. An episode is a single attempt to win the game. It begins when I call the reset function and ends when I either fill the board or make a wrong move. Because the first move on an empty board is always valid, an episode consists of 2 to 9 actions.

In every episode, I keep taking actions until the environment reaches a terminal state. In the inner loop, I randomly pick a number between 0 and 8: the location on the board that I am going to choose. I pass that number to the step function. As a result, I get the next state of the environment and the reward for taking that action.

In the end, I gather the results and print them to the console. I decided to play 10000 episodes, hoping to win the game (fill the whole board without making an illegal move) at least once.

rewards = []
steps = []
num_episodes = 10000

for _ in range(num_episodes):
  episode_reward = 0
  episode_steps = 0
  time_step = tf_env.reset()
  while not time_step.is_last():
    # a batch of one random board position between 0 and 8
    action = tf.random.uniform([1], 0, 9, dtype=tf.int32)
    time_step = tf_env.step(action)
    episode_steps += 1
    episode_reward += time_step.reward.numpy()
  rewards.append(episode_reward)
  steps.append(episode_steps)

num_steps = np.sum(steps)
avg_length = np.mean(steps)
avg_reward = np.mean(rewards)
max_reward = np.max(rewards)
max_length = np.max(steps)

print('num_episodes:', num_episodes, 'num_steps:', num_steps)
print('avg_length', avg_length, 'avg_reward:', avg_reward)
print('max_length', max_length, 'max_reward:', max_reward)

Result:

num_episodes: 10000 num_steps: 44616
avg_length 4.4616 avg_reward: -0.65164006
max_length 9 max_reward: 1.8000001
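As a sanity check that is not part of the original experiment, the expected results for a uniformly random player can be computed exactly. The k-th move lands on an empty spot with probability (10 - k)/9, so the probability that an episode lasts at least k + 1 steps is the product of those factors, and winning (nine distinct moves) has probability 9!/9^9. Every move except the last earns 0.1, and the last one earns +1 or -1. The function name random_play_statistics is hypothetical.

```python
from fractions import Fraction

BOARD_SIZE = 9
VALID_REWARD = Fraction(1, 10)


def random_play_statistics(board_size=BOARD_SIZE, valid_reward=VALID_REWARD):
    """Exact expected episode length, win probability and expected return
    for an agent that picks every spot uniformly at random."""
    expected_length = Fraction(0)
    survival = Fraction(1)  # P(L >= k), starting with k = 1
    for k in range(1, board_size + 1):
        expected_length += survival  # E[L] = sum over k of P(L >= k)
        # move k hits an empty spot with probability (board_size + 1 - k) / board_size
        survival *= Fraction(board_size + 1 - k, board_size)
    win_probability = survival  # all nine moves were distinct
    # All moves but the last earn valid_reward; the last earns +1 or -1.
    expected_return = (valid_reward * (expected_length - 1)
                       + win_probability - (1 - win_probability))
    return expected_length, win_probability, expected_return


length, p_win, ret = random_play_statistics()
print(float(length))         # about 4.4574
print(float(p_win) * 10000)  # expected wins in 10000 episodes, about 9.37
print(float(ret))            # about -0.6524
```

The computed values agree closely with the measured avg_length of 4.4616 and avg_reward of -0.65164, and the max_reward of 1.8 is exactly eight valid moves at 0.1 plus the final reward of 1.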

Single responsibility principle

When you look at the environment implementation above, you will probably notice that it violates several good software development practices. The game rules are interleaved with the Tensorflow-specific code, and the game logic works directly with the data types that Tensorflow expects.

Such an implementation may be fast (because there is no conversion between types), but if the game rules were complicated, it would be challenging to write them this way.

In such a situation, I would probably prefer to sacrifice a little bit of performance and implement the code in a more human-friendly way.

First of all, we can name the possible results of an action. It will help us understand what happened in the environment.

from enum import Enum
class ActionResult(Enum):
  VALID_MOVE = 1
  BOARD_FULL = 2
  ILLEGAL_MOVE = 3

Now, I can move the game rules to a separate class that “knows” nothing about the existence of Tensorflow. I will also replace the numeric array with boolean values because the board cells have only two possible states (occupied or not).

Note that I renamed the _step function to mark_spot and now it returns the enum representing the result of an action.

class TicTacToeBoardWithNoRulesBusinessLogic():
  def __init__(self):
    self._state = [False, False, False, False, False, False, False, False, False]
    self._game_ended = False

  def reset(self):
    self._state = [False, False, False, False, False, False, False, False, False]
    self._game_ended = False

  def __is_spot_empty(self, index):
    return not self._state[index]

  def __all_spots_occupied(self):
    return all(self._state)

  def mark_spot(self, spot_position):
    if spot_position < 0 or spot_position > 8:
      raise ValueError("Action must be between 0 and 8.")

    if self.__is_spot_empty(spot_position):
      self._state[spot_position] = True

      if self.__all_spots_occupied():
        self._game_ended = True
        return ActionResult.BOARD_FULL
      else:
        return ActionResult.VALID_MOVE

    else:
      self._game_ended = True
      return ActionResult.ILLEGAL_MOVE

  def game_ended(self):
    return self._game_ended

  def board_state(self):
    return self._state

Finally, I can extend the PyEnvironment class, but this time, it contains only the Tensorflow-specific code. It works as an adapter between the Tensorflow agent and the business rules of the game.

class TicTacToeBoardWithNoRulesEnvironment(py_environment.PyEnvironment):

  def __init__(self, game):
    self._action_spec = array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=8, name='action')
    self._observation_spec = array_spec.BoundedArraySpec(
        shape=(1,9), dtype=np.int32, minimum=0, maximum=1, name='observation')
    self._game = game

  def action_spec(self):
    return self._action_spec

  def observation_spec(self):
    return self._observation_spec

  def _reset(self):
    self._game.reset()
    return ts.restart(np.array([self.__state_to_observation()], dtype=np.int32))

  def __state_to_observation(self):
    return list(map(lambda x: int(x), self._game.board_state()))

  def _step(self, action):
    if self._game.game_ended():
      return self.reset()

    result = self._game.mark_spot(action)

    if result == ActionResult.VALID_MOVE:
      return ts.transition(np.array([self.__state_to_observation()], dtype=np.int32), reward=0.1, discount=1.0)
    elif result == ActionResult.BOARD_FULL:
      return ts.termination(np.array([self.__state_to_observation()], dtype=np.int32), 1)
    else:
      return ts.termination(np.array([self.__state_to_observation()], dtype=np.int32), -1)

Usage:

environment = TicTacToeBoardWithNoRulesEnvironment(TicTacToeBoardWithNoRulesBusinessLogic())
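Because _step now only translates an ActionResult into a reward and an "episode ended" decision, that translation could also be written as a lookup table. Below is a minimal sketch of this alternative design; OUTCOMES and outcome_for are hypothetical names, and the enum is repeated so that the snippet runs on its own.

```python
from enum import Enum


class ActionResult(Enum):
    VALID_MOVE = 1
    BOARD_FULL = 2
    ILLEGAL_MOVE = 3


# Hypothetical lookup table: action result -> (reward, episode ends?)
OUTCOMES = {
    ActionResult.VALID_MOVE: (0.1, False),
    ActionResult.BOARD_FULL: (1.0, True),
    ActionResult.ILLEGAL_MOVE: (-1.0, True),
}


def outcome_for(result):
    """Return the reward and the 'episode ended' flag for an action result."""
    return OUTCOMES[result]


reward, ended = outcome_for(ActionResult.ILLEGAL_MOVE)
print(reward, ended)  # -1.0 True
```

Inside _step, the flag would select ts.termination(observation, reward) or ts.transition(observation, reward=reward, discount=1.0). Adding a new outcome then means adding one dictionary entry, and a forgotten entry fails loudly with a KeyError instead of silently falling into the else branch.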

Please take a look at the __state_to_observation function, which I use to convert the internal state of the game to values that Tensorflow can use. Such a separation allows me to implement a rich domain model of the environment and convert it to the primitive types used by Tensorflow.

It also makes unit tests much easier to write and understand: the tests of the domain model check only the business rules, and the tests of the Tensorflow part need to check only the conversion.
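As an illustration of how such tests might look, here is a sketch using the standard unittest module. To keep the snippet runnable on its own, it repeats a condensed copy of the business-logic class (same behaviour) and replaces the private __state_to_observation method with a hypothetical standalone function, state_to_observation. None of these tests needs Tensorflow.

```python
import unittest
from enum import Enum

import numpy as np


class ActionResult(Enum):
    VALID_MOVE = 1
    BOARD_FULL = 2
    ILLEGAL_MOVE = 3


class TicTacToeBoardWithNoRulesBusinessLogic:
    """Condensed copy of the business-logic class, so the tests run standalone."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._state = [False] * 9
        self._game_ended = False

    def mark_spot(self, spot_position):
        if spot_position < 0 or spot_position > 8:
            raise ValueError("Action must be between 0 and 8.")
        if self._state[spot_position]:
            self._game_ended = True
            return ActionResult.ILLEGAL_MOVE
        self._state[spot_position] = True
        if all(self._state):
            self._game_ended = True
            return ActionResult.BOARD_FULL
        return ActionResult.VALID_MOVE

    def game_ended(self):
        return self._game_ended


def state_to_observation(board_state):
    """Hypothetical standalone conversion: list of bools -> (1, 9) int32 array."""
    return np.array([[int(cell) for cell in board_state]], dtype=np.int32)


class BusinessRulesTest(unittest.TestCase):
    def test_first_move_is_valid(self):
        game = TicTacToeBoardWithNoRulesBusinessLogic()
        self.assertEqual(game.mark_spot(4), ActionResult.VALID_MOVE)
        self.assertFalse(game.game_ended())

    def test_repeated_spot_is_illegal(self):
        game = TicTacToeBoardWithNoRulesBusinessLogic()
        game.mark_spot(4)
        self.assertEqual(game.mark_spot(4), ActionResult.ILLEGAL_MOVE)
        self.assertTrue(game.game_ended())

    def test_filling_the_board(self):
        game = TicTacToeBoardWithNoRulesBusinessLogic()
        results = [game.mark_spot(i) for i in range(9)]
        self.assertEqual(results[-1], ActionResult.BOARD_FULL)
        self.assertTrue(game.game_ended())

    def test_out_of_range_spot(self):
        with self.assertRaises(ValueError):
            TicTacToeBoardWithNoRulesBusinessLogic().mark_spot(9)


class ConversionTest(unittest.TestCase):
    def test_observation_shape_and_values(self):
        state = [True, False, False, False, True, False, False, False, False]
        observation = state_to_observation(state)
        self.assertEqual(observation.shape, (1, 9))
        self.assertEqual(observation.dtype, np.int32)
        self.assertEqual(observation.tolist(), [[1, 0, 0, 0, 1, 0, 0, 0, 0]])


if __name__ == "__main__":
    unittest.main(argv=["tests"], exit=False)
```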
