Skip to content

Commit bc6dc24

Browse files
committed
feat: merge PR aimacode#1300
1 parent d6d4991 commit bc6dc24

File tree

1 file changed

+160
-1
lines changed

1 file changed

+160
-1
lines changed

notebook4e.py

+160-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import time
22
from collections import defaultdict
33
from inspect import getsource
4+
import heapq
5+
import random
46

57
import ipywidgets as widgets
68
import matplotlib.pyplot as plt
@@ -10,12 +12,14 @@
1012
from IPython.display import display
1113
from PIL import Image
1214
from matplotlib import lines
15+
from matplotlib.animation import FuncAnimation
1316
from matplotlib.colors import ListedColormap
1417

1518
from games import TicTacToe, alpha_beta_player, random_player, Fig52Extended
1619
from learning import DataSet
1720
from logic import parse_definite_clause, standardize_variables, unify_mm, subst
18-
from search import GraphProblem, romania_map
21+
from search import GraphProblem, romania_map, Problem, Node
22+
from utils import PriorityQueue
1923

2024

2125
# ______________________________________________________________________________
@@ -1156,3 +1160,158 @@ def plot_pomdp_utility(utility):
11561160
plt.text((right + left) / 2 - 0.02, 10, 'Ask')
11571161
plt.text((right + 1) / 2 - 0.07, 10, 'Delete')
11581162
plt.show()
1163+
1164+
# --------------------------- search problems Animation Class. -------------------------------------------------
1165+
1166+
1167+
def transpose(matrix): return list(zip(*matrix))
1168+
1169+
1170+
def straight_line_distance(A, B):
1171+
"Straight-line distance between two points."
1172+
return sum(abs(a - b)**2 for (a, b) in zip(A, B)) ** 0.5
1173+
1174+
def random_lines(X=range(15, 130), Y=range(60), N=150, lengths=range(6, 12)):
1175+
"""The set of cells in N random lines of the given lengths."""
1176+
result = set()
1177+
for _ in range(N):
1178+
x, y = random.choice(X), random.choice(Y)
1179+
dx, dy = random.choice(((0, 1), (1, 0)))
1180+
result |= line(x, y, dx, dy, random.choice(lengths))
1181+
return result
1182+
1183+
def line(x, y, dx, dy, length):
1184+
"""A line of `length` cells starting at (x, y) and going in (dx, dy) direction."""
1185+
return {(x + i * dx, y + i * dy) for i in range(length)}
1186+
1187+
1188+
class AnimateProblem(Problem):
1189+
directions = [(-1, -1), (0, -1), (1, -1),
1190+
(-1, 0), (1, 0),
1191+
(-1, +1), (0, +1), (1, +1)]
1192+
1193+
def __init__(self, solver, weight=1.4,
1194+
height=20, width=40, cell_weights=None,
1195+
initial=(1, 1), goal=(35, 19),
1196+
obstacles=random_lines(X=range(40), Y=range(20), N=80, lengths=range(1, 7))):
1197+
"""Animate the Grid Problem"""
1198+
self.height = height
1199+
self.width = width
1200+
self.initial = initial
1201+
self.goal = goal
1202+
self.obstacles = set(obstacles) - {self.initial, self.goal}
1203+
self.weight = weight
1204+
# We may change the cell_weights in case of Uniform Cost search
1205+
self.cell_weights = cell_weights
1206+
if self.cell_weights is None:
1207+
self.cell_weights = np.ones((self.width+5, self.height+5), dtype=np.int16)
1208+
# Define all the allowed solvers and their f-value function.
1209+
# TODO: Bidirectional Search, Iterative Deepening Search.
1210+
self.SOLVERS = {'astar': (lambda n: n.path_cost + self.h(n)),
1211+
'wastar': (lambda n: n.path_cost + self.weight*self.h(n)),
1212+
'bfs': (lambda n: n.depth),
1213+
'dfs': (lambda n: -n.depth),
1214+
'ucs': (lambda n: n.path_cost),
1215+
'bestfs': (lambda n: self.h(n))
1216+
}
1217+
self.solver_f = self.SOLVERS[solver] # Assign the solver's f-value function
1218+
self.solver = solver
1219+
self.__initial_node = Node(self.initial)
1220+
# Dictionary of reach nodes. Simlar to `explored` set.
1221+
self.reached = {self.initial: self.__initial_node}
1222+
# Frontier of nodes to be explored!
1223+
self.frontier = PriorityQueue(f=self.solver_f)
1224+
self.frontier.append(self.__initial_node)
1225+
# We will draw each frame onto this figure
1226+
self.fig, self.ax = plt.subplots(figsize=(10, 6))
1227+
self.solution = [(-1, -1)]
1228+
self.ax.axis('off')
1229+
self.ax.axis('equal')
1230+
self.done = False
1231+
1232+
def h(self, node): return straight_line_distance(node.state, self.goal)
1233+
1234+
def result(self, state, action):
1235+
"Both states and actions are represented by (x, y) pairs."
1236+
return action if action not in self.obstacles else state
1237+
1238+
def draw_walls(self):
1239+
self.obstacles |= {(i, -2) for i in range(-2, self.width+4)}
1240+
self.obstacles |= {(i, self.height+4) for i in range(-2, self.width+4)}
1241+
self.obstacles |= {(-2, j) for j in range(-2, self.height+5)}
1242+
self.obstacles |= {(self.width+4, j) for j in range(-2, self.height+5)}
1243+
1244+
def actions(self, state):
1245+
"""You can move one cell in any of `directions` to a non-obstacle cell."""
1246+
x, y = state
1247+
return {(x + dx, y + dy) for (dx, dy) in self.directions} - self.obstacles
1248+
1249+
def path_cost(self, c, state1, action, state2):
1250+
"""Return the cost of moving from s to s1"""
1251+
return c + self.cell_weights[state2[0]][state2[1]]
1252+
1253+
def step(self, frame):
1254+
"""
1255+
One step of search algorithm.
1256+
Explore a node in the frontier and plot
1257+
all the scatter plots again to create a frame.
1258+
A collection of these frames will be used to
1259+
create the animation using matplotlib.
1260+
"""
1261+
# If we are done, don't do anything.
1262+
if self.done:
1263+
return self.sc1, self.sc2, self.sc3, self.sc4, self.sc5, self.sc6
1264+
1265+
# Run the search algorithm for a single
1266+
# node in the frontier.
1267+
node = self.frontier.pop()
1268+
self.solution = node.solution()
1269+
if self.goal_test(node.state):
1270+
self.done = True
1271+
else:
1272+
for child in node.expand(self):
1273+
s = child.state
1274+
if s not in self.reached or child.path_cost < self.reached[s].path_cost:
1275+
self.reached[s] = child
1276+
self.frontier.append(child)
1277+
1278+
# Plot all the new states onto our figure
1279+
# and return them to matplotlib for creating animation.
1280+
self.ax.clear()
1281+
self.ax.axis('off')
1282+
self.ax.axis('equal')
1283+
self.sc1 = self.ax.scatter(*transpose(self.obstacles), marker='s', color='darkgrey')
1284+
self.sc2 = self.ax.scatter(*transpose(list(self.reached)), 1**2, marker='.', c='blue')
1285+
self.sc3 = self.ax.scatter(*transpose(self.solution), marker='s', c='blue')
1286+
self.sc4 = self.ax.scatter(*transpose([node.state]), 9**2, marker='8', c='yellow')
1287+
self.sc5 = self.ax.scatter(*transpose([self.initial]), 9**2, marker='D', c='green')
1288+
self.sc6 = self.ax.scatter(*transpose([self.goal]), 9**2, marker='8', c='red')
1289+
plt.title("Explored: {}, Path Cost: {}\nSolver: {}".format(len(self.reached), node.path_cost, self.solver))
1290+
return self.sc1, self.sc2, self.sc3, self.sc4, self.sc5, self.sc6
1291+
1292+
def run(self, frames=200):
1293+
"""
1294+
Run the main loop of the problem to
1295+
create an animation. If you are running
1296+
on your local machine, you can save animations
1297+
in you system by using the following commands:
1298+
First, you need to download the ffmpeg using:
1299+
Linux/MacOS: `sudo apt install ffmpeg`
1300+
Then you can use the following line of code to generate
1301+
a video of the animation.
1302+
Linux/MacOS : `anim.save('animation.mp4')`
1303+
For Windows users, the process is a little longer:
1304+
Download ffmpeg by following this article: https://www.wikihow.com/Install-FFmpeg-on-Windows
1305+
Then the animation can be saved in a video format as follows:
1306+
Windows: `anim.save('animation.mp4')`
1307+
1308+
If the animation is not complete, increase the number
1309+
of frames in the below lines of code.
1310+
"""
1311+
anim = FuncAnimation(self.fig, self.step, blit=True, interval=200, frames=frames)
1312+
# If you want to save your animations, you can comment either
1313+
# of the lines below.
1314+
# NOTE: FFmpeg is needed to render a .mp4 video of the animation.
1315+
# anim.save('astar.mp4')
1316+
# anim.save('animation.html')
1317+
return HTML(anim.to_html5_video())

0 commit comments

Comments
 (0)