SNAKE_QTABLE_PATH = '/content/gdrive/MyDrive/snake.qtable.json.gz'
# The checkpoint lives next to the Q-table: .../snake.qstory.chckpnt.json
SNAKE_QSTORY_PATH = SNAKE_QTABLE_PATH[:-len('qtable.json.gz')] + 'qstory.chckpnt.json'
from google.colab import drive
drive.mount('/content/gdrive')
Mounted at /content/gdrive
# Pin the exact version of the game implementation the Q-table is trained against.
md5 = !md5sum 'snake_game.py'
assert 'fe6f1b08c885095e7425b25c5a7ad9e5' == md5[0].split()[0]
import numpy as np
import matplotlib.pyplot as plt
import copy
import json
import gzip
import hashlib
from snake_game import SnakeGame
from datetime import datetime, timedelta
env = SnakeGame()
def extract_features(game=env):
    snake, food = game.get_state()
    # Probe each action on a throwaway copy of the game: would it end the game?
    obstacle = lambda k: copy.deepcopy(game).step(k).done
    # Pack the per-action danger flags into a single integer bitmask.
    obs = sum(v << i for i, v in enumerate(obstacle(k) for k in game.ACTION))
    return obs, snake.head.h, snake.head.w, food.h, food.w
extract_features()
(0, 0, 14, 7, 31)
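To make the obstacle bitmask concrete: extract_features probes every action on a throwaway copy of the game and packs the resulting danger flags into one integer. A tiny illustration (the danger flags below are made up, and the action order is whatever env.ACTION defines):
danger = [True, False, False, True]  # hypothetical: the first and last actions would end the game
sum(v << i for i, v in enumerate(danger))  # 1*1 + 0*2 + 0*4 + 1*8 == 9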
Now we'll create our Q-table. To know how many rows (states) and columns (actions) it needs, we have to calculate the state space size and the action space size, as sketched below.
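As a sanity check of that arithmetic for this run (a 16x32 board and 4 actions, matching the table shape printed further down):
action_space_size = 4                            # len(env.ACTION)
obstacle_states = 1 << action_space_size         # 16 possible danger bitmasks
cells = 16 * 32                                  # board_height * board_width
obstacle_states * cells**2 * action_space_size   # 16777216 Q-values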
def save_checkpoint(qtable, history):
    data = qtable.tolist(), history
    # Digest every item so the loader can detect a corrupted or truncated file.
    md5 = hashlib.md5()
    for part in data:
        for item in part:
            md5.update(json.dumps(item).encode())
    with open(SNAKE_QSTORY_PATH, 'w') as f:
        json.dump([data, md5.hexdigest()], f)
    # Refresh the backup only after the primary write has completed.
    !cp "$SNAKE_QSTORY_PATH" "$SNAKE_QSTORY_PATH".bak
def maybe_load_checkpoint():
    # Try the primary checkpoint first, then the backup copy.
    for candidate in SNAKE_QSTORY_PATH, SNAKE_QSTORY_PATH + '.bak':
        try:
            with open(candidate, 'r') as f:
                data, val_md5 = json.load(f)
            # Recompute the digest and reject a corrupted or truncated file.
            md5 = hashlib.md5()
            for part in data:
                for item in part:
                    md5.update(json.dumps(item).encode())
            if val_md5 == md5.hexdigest():
                print('Qtable and history loaded from', candidate[24:], val_md5)
                qtable, history = data
                qtable = np.array(qtable)
                return qtable, history
        except (OSError, ValueError):
            pass
    print("Training from scratch")
    action_space_size = len(env.ACTION)
    # One axis per feature: obstacle bitmask, head (h, w), food (h, w), action.
    qtable = np.random.rand(
        1 << action_space_size,
        env.board_height, env.board_width,
        env.board_height, env.board_width,
        action_space_size)
    return qtable, []
qtable, history = maybe_load_checkpoint()
qtable.shape, qtable.size, len(history)
Qtable and history loaded from snake.qstory.chckpnt.json 6734b06ae619485145c0f8a968cb0c3d
((16, 16, 32, 16, 32, 4), 16777216, 54754507)
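The six axes correspond to the feature tuple returned by extract_features plus the chosen action: (obstacle bitmask, head row, head column, food row, food column, action).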
plt.rc('figure', figsize=(13, 5))
# How evenly does the greedy policy spread over the four actions?
balance = [np.sum(qtable.argmax(axis=-1) == i.value) for i in env.ACTION]
plt.bar(range(len(balance)), balance, tick_label=[i.name for i in env.ACTION]) and None
ALPHA = 0.2  # Learning rate
GAMMA = 0.9  # Discount factor
# Exploration parameters
max_epsilon = 1.0  # Exploration probability at the start
min_epsilon = 0.01  # Minimum exploration probability
epsilon_decay = 1e-07  # Exponential decay rate for the exploration probability
time_limit = timedelta(hours=11, minutes=30)  # Stay under Colab's 12-hour session limit
stop_time = datetime.now() + time_limit
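Before training, it's worth checking how slowly epsilon anneals at this decay rate; a quick sketch (episode counts chosen for illustration):
for n in (0, 10**6, 10**7, 5 * 10**7):
    print(n, round(min_epsilon + (max_epsilon - min_epsilon) * np.exp(-epsilon_decay * n), 3))
# n=0 -> 1.0, n=1e6 -> 0.906, n=1e7 -> 0.374, n=5e7 -> 0.017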
# Train for life, or until the session time budget expires
while datetime.now() < stop_time:
# Reset the environment
env.reset()
state = extract_features()
done = False
total_rewards = 0
total_steps = 0
# Reduce epsilon (because we need less and less exploration)
epsilon = (max_epsilon - min_epsilon) * np.exp(-epsilon_decay * len(history))
epsilon += min_epsilon
while not done:
# 3. Choose an action a in the current world state (s)
# First we randomize a number
exp_exp_tradeoff = np.random.random()
        # If this number is greater than epsilon -->
        # exploitation (take the action with the biggest Q-value for this state)
if exp_exp_tradeoff > epsilon:
action = np.argmax(qtable[state])
# Else doing a random choice --> exploration
else:
action = env.random_action()
# Take the action (a) and observe the outcome state(s') and reward (r)
_, reward, done, info = env.step(action)
new_state = extract_features()
        # Update Q(s,a) := Q(s,a) + ALPHA * [R(s,a) + GAMMA * max Q(s',a') - Q(s,a)]
        # qtable[new_state] holds the Q-values of all actions in the new state
i = state + (action,)
qtable[i] += ALPHA * (reward + GAMMA * qtable[new_state].max() - qtable[i])
total_rewards += reward
total_steps += 1
        # The new state becomes the current state
state = new_state
history.append((total_rewards, total_steps, info.score))
if len(history) % 1000 == 0:
        # Transpose the last 1000 episodes into (rewards, steps, scores) columns
        h1k = zip(*history[-1000:])
print('\r%7d' % (len(history) // 1000),
'| Epsilon: %0.2f' % epsilon,
'| Reward: %5.2f' % np.average(next(h1k)),
'| Age: %4d' % np.average(next(h1k)),
'| Score: %4d' % np.average(next(h1k)),
'| %d %%' % (100 * (1 - (stop_time - datetime.now()) / time_limit)),
end = '')
save_checkpoint(qtable, history)
55238 | Epsilon: 0.01 | Reward: 13.21 | Age: 184 | Score: 1251 | 99 %
history = np.array(history)
history.shape
(55238661, 3)
def plot_history(index, label):
plt.rc('figure', figsize=(13, 4))
plt.plot(history[:,index])
plt.ylabel(label)
plt.xlabel('Games')
plot_history(0, 'Rewards')
plot_history(1, 'Steps')
plot_history(2, 'Score')
# Persist only the greedy policy (argmax over the action axis); it is far smaller than the full table.
with gzip.open(SNAKE_QTABLE_PATH, 'wt') as f:
    json.dump(qtable.argmax(axis=-1).tolist(), f)
!md5sum "$SNAKE_QTABLE_PATH"
8ecd7bb135420f8293c68ea51f7d1659 /content/gdrive/MyDrive/snake.qtable.json.gz
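As a minimal sketch of how the exported file could be consumed later (assuming the same SnakeGame environment and extract_features encoding as above):
with gzip.open(SNAKE_QTABLE_PATH, 'rt') as f:
    policy = np.array(json.load(f))  # shape (16, 16, 32, 16, 32): best action per state
policy[extract_features()]           # greedy action index into env.ACTION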