In [1]:
SNAKE_QTABLE_PATH = '/content/gdrive/MyDrive/msp430.snake.qtable.json.gz'
SNAKE_QSTORY_PATH = SNAKE_QTABLE_PATH[:31] + 'qstory.chckpnt.json'
In [2]:
from google.colab import drive
drive.mount(SNAKE_QTABLE_PATH[:15])
Mounted at /content/gdrive
In [4]:
md5 = !md5sum 'snake_game.py'
assert 'fe6f1b08c885095e7425b25c5a7ad9e5' == md5[0].split()[0]
In [5]:
import numpy as np
import matplotlib.pyplot as plt
import copy
import json
import gzip
import hashlib
from snake_game import SnakeGame
from datetime import datetime, timedelta

Step 1: Create the environment 🎮¶

  • Here we'll create the FrozenLake environment.
  • OpenAI Gym is a library composed of many environments that we can use to train our agents.
  • In our case we choose to use Frozen Lake.
In [6]:
env = SnakeGame()
In [7]:
def extract_features(game=env):
  snake, food = game.get_state()
  obsticle = lambda k: copy.deepcopy(game).step(k).done
  o = sum(v << i for i, v in enumerate(obsticle(k) for k in game.ACTION))
  a, b = snake.head.h - food.h, snake.head.w - food.w
  c, d = abs(a) < env.board_height / 2, abs(b) < env.board_width / 2
  a, b = (max(min(1, i), -1) + 1 for i in (a, b))
  c, d = (int(i) for i in (c, d))
  return o, a, b, c, d
In [8]:
extract_features()
Out[8]:
(0, 0, 0, 1, 1)

Step 2: Create the Q-table and initialize it 🗄️¶

Now, we'll create our Q-table, to know how much rows (states) and columns (actions) we need, we need to calculate the action_size and the state_size

In [9]:
def save_checkpoint(qtable, history):
  data = qtable.tolist(), history
  md5  = hashlib.md5(json.dumps(data).encode()).hexdigest()
  with open(SNAKE_QSTORY_PATH, 'w') as f:
    f.write(json.dumps([data, md5]))
  !cp "$SNAKE_QSTORY_PATH" "$SNAKE_QSTORY_PATH".bak
In [10]:
def maybe_load_checkpoint():
  for candidate in SNAKE_QSTORY_PATH, SNAKE_QSTORY_PATH + '.bak':
    try:
      with open(candidate, 'r') as f:
        data, md5 = json.load(f)
      if md5 == hashlib.md5(json.dumps(data).encode()).hexdigest():
        print('Qtable and history loaded from', candidate[24:], md5)
        qtable, history = data
        qtable = np.array(qtable)
        return qtable, history
    except:
      pass
  print("Training from scratch")
  action_space_size = len(env.ACTION)
  qtable = np.random.rand(1 << action_space_size, 3, 3, 2, 2, action_space_size)
  return qtable, []
In [11]:
qtable, history = maybe_load_checkpoint()
qtable.shape, qtable.size, len(history)
Qtable and history loaded from msp430.qstory.chckpnt.json c6ba76c6c764683f80bb8ad226b332c8
Out[11]:
((16, 3, 3, 2, 2, 4), 2304, 2386361)
In [12]:
plt.rc('figure', figsize=(13, 5))
balance = [np.sum(qtable.argmax(axis=-1) == i.value) for i in env.ACTION]
plt.bar(range(len(balance)), balance, tick_label=[i.name for i in env.ACTION]) and None

Step 3: Create the hyperparameters ⚙️¶

  • Here, we'll specify the hyperparameters
In [13]:
ALPHA = 0.2                    # Learning rate
GAMMA = 0.9                    # Discounting rate

# Exploration parameters
max_epsilon = 1.0              # Exploration probability at start
min_epsilon = 0.01             # Minimum exploration probability 
epsilon_decay = 1e-05          # Exponential decay rate for exploration prob

Step 4: The Q learning algorithm 🧠¶

  • Now we implement the Q learning algorithm: Q algo
In [14]:
time_limit = timedelta(hours=11, minutes=50) # colab 12 hours limit
stop_time  = datetime.now() + time_limit

# 2 For life or until session time expired
while datetime.now() < stop_time:

  # Reset the environment
  env.reset()
  state = extract_features()
  done = False
  total_rewards = 0
  total_steps = 0

  # Reduce epsilon (because we need less and less exploration)
  epsilon = (max_epsilon - min_epsilon) * np.exp(-epsilon_decay * len(history))
  epsilon += min_epsilon

  while not done:

    # 3. Choose an action a in the current world state (s)
    # First we randomize a number
    exp_exp_tradeoff = np.random.random()

    # If this number > greater than epsilon -->
    # exploitation (taking the biggest Q value for this state)
    if exp_exp_tradeoff > epsilon:
      action = np.argmax(qtable[state])

    # Else doing a random choice --> exploration
    else:
      action = env.random_action()

    # Take the action (a) and observe the outcome state(s') and reward (r)
    _, reward, done, info = env.step(action)
    new_state = extract_features()

    # Update Q(s,a):= Q(s,a) + lr * [R(s,a) + GAMMA * max Q(s',a') - Q(s,a)]
    # qtable[new_state,:] : all the actions we can take from new state
    i = state + (action,)
    qtable[i] += ALPHA * (reward + GAMMA * qtable[new_state].max() - qtable[i])

    total_rewards += reward
    total_steps += 1

    # Our new state is state
    state = new_state

  history.append((total_rewards, total_steps, info.score))
  if len(history) % 1000 == 0:
    h1k = zip(*history[-1000:])
    print('\r%7d' % (len(history) // 1000),
          '| Epsilon: %0.2f' % epsilon,
          '| Reward: %5.2f'  % np.average(next(h1k)),
          '| Age: %4d'       % np.average(next(h1k)),
          '| Score: %4d'     % np.average(next(h1k)),
          '| %d %%' % (100 * (1 - (stop_time - datetime.now()) / time_limit)),
          end = '')
    
save_checkpoint(qtable, history)
   2605 | Epsilon: 0.01 | Reward: 16.57 | Age:  271 | Score: 1503 | 99 %
In [15]:
plt.rc('figure', figsize=(13, 4))
history = zip(*history)
In [16]:
plt.plot(next(history))
plt.ylabel('Rewards')
plt.xlabel('Games') and None
In [17]:
plt.plot(next(history))
plt.ylabel('Steps')
plt.xlabel('Games') and None
In [18]:
plt.plot(next(history))
plt.ylabel('Score')
plt.xlabel('Games') and None
In [19]:
with gzip.open(SNAKE_QTABLE_PATH, 'wt') as f:
  json.dump(qtable.argmax(axis=-1).tolist(), f)
!md5sum "$SNAKE_QTABLE_PATH"
b8bc01f832c0787147ad0fddbf9c4d76  /content/gdrive/MyDrive/msp430.snake.qtable.json.gz