-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo.py
More file actions
193 lines (154 loc) · 6.56 KB
/
demo.py
File metadata and controls
193 lines (154 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import cv2
import numpy as np
import os
import torch
from conv_net import ConvNet
from detect_balls import find_circles
from find_table_corners import table_corners
from find_best_shot import *
from project_board import *
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
test_number = 3   # selects the input video full_test_<test_number>.mp4
ckpt_epoch = 9    # selects the classifier checkpoint epoch_<ckpt_epoch>.pt
player = "solid"  # "solid" aims at solids; any other value aims at stripes
data_dir = "."    # directory holding the test video
# The checkpoint is resolved relative to the current working directory
# (not data_dir), exactly as the original single-argument os.path.join did.
model_weights = "epoch_%d.pt" % ckpt_epoch
# NOTE: currently unused -- inference is always pinned to the CPU below.
use_cuda = torch.cuda.is_available()
device = torch.device("cpu")

model = ConvNet().to(device)
model.eval()
model.load_state_dict(torch.load(model_weights, map_location=device))


def _to_cv_point(xy):
    """Return the first two entries of *xy* as an ``(int, int)`` pixel tuple.

    OpenCV drawing functions (cv2.circle / cv2.line) require integer pixel
    coordinates, while detections and (un)projected points may be floats.
    """
    return int(xy[0]), int(xy[1])


def _classify_ball(frame, c_x, c_y, radius, frame_w, frame_h):
    """Crop the ball's bounding box from *frame* and return the model's label.

    The box is clamped to the frame. Returns None when the clamped crop is
    empty (e.g. a zero radius), because cv2.resize raises on empty input.
    """
    l1 = max(0, int(c_x) - int(radius))
    l2 = max(0, int(c_y) - int(radius))
    r1 = min(frame_w, int(c_x) + int(radius))
    r2 = min(frame_h, int(c_y) + int(radius))
    bbox = frame[l2:r2, l1:r1]
    if bbox.size == 0:
        return None
    bbox = cv2.resize(bbox, (40, 40), interpolation=cv2.INTER_AREA)
    # Tensor is (1, 40, 40, channels) in the crop's HWC layout -- presumably
    # ConvNet expects exactly this layout; TODO confirm against conv_net.
    bbox_tensor = torch.from_numpy(bbox).unsqueeze(0).to(device, dtype=torch.float)
    # Inference only: skip autograd bookkeeping for every detected ball.
    with torch.no_grad():
        return model(bbox_tensor).argmax(dim=1).item()


cap = cv2.VideoCapture(os.path.join(data_dir, "full_test_%d.mp4" % test_number))
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            print("Can't receive frame (stream end?). Exiting ...")
            break
        frame_h, frame_w = frame.shape[0], frame.shape[1]
        # Annotated copy; `frame` itself stays clean for cropping balls.
        pool = np.copy(frame)

        # 1. find table region and corners
        corners, hls_mask = table_corners(pool)
        corners = order_corners(corners)
        for corner in corners:
            cv2.circle(pool, _to_cv_point(corner), 30, (0, 0, 255), -1)

        # 1b. constrain mask to table
        only_table_mask = np.zeros(hls_mask.shape, dtype=np.uint8)
        cv2.fillPoly(only_table_mask, [corners], (255))
        hls_mask = cv2.bitwise_and(hls_mask, hls_mask, mask=only_table_mask)

        # 2. find and classify pool balls
        #    label 0 -> white, 8 -> black, > 8 -> stripes, otherwise solids
        w_x, w_y = None, None
        b_x, b_y = None, None
        stripe_centers = []
        solid_centers = []
        circles = find_circles(hls_mask)
        if circles is not None:
            for circle in circles[0]:
                c_x, c_y, radius = circle[0], circle[1], circle[2]
                label = _classify_ball(frame, c_x, c_y, radius, frame_w, frame_h)
                if label is None:
                    continue
                if label == 0:
                    w_x, w_y = c_x, c_y
                elif label == 8:
                    b_x, b_y = c_x, c_y
                elif label > 8:
                    stripe_centers.append([c_x, c_y])
                else:
                    solid_centers.append([c_x, c_y])
                # draw the circle and center
                center = _to_cv_point((c_x, c_y))
                cv2.circle(pool, center, int(radius), (0, 255, 0), 2)
                cv2.circle(pool, center, 2, (0, 0, 255), 3)

        # 3. homographic projection (named `homography`, not `h`, so it no
        #    longer shadows the frame height)
        homography = compute_homography(corners)
        stripe_centers = [project(ball, homography) for ball in stripe_centers]
        solid_centers = [project(ball, homography) for ball in solid_centers]
        white_center = None
        black_center = None
        if w_x is not None and w_y is not None:
            white_center = project(np.asarray([w_x, w_y]), homography)
        if b_x is not None and b_y is not None:
            black_center = project(np.asarray([b_x, b_y]), homography)

        # Pockets: the four projected corners plus the midpoints of corner
        # pairs (0, 3) and (1, 2).
        pockets = [project(corner, homography) for corner in corners]
        pockets.append(((pockets[0] + pockets[3]) / 2).astype(int))
        pockets.append(((pockets[1] + pockets[2]) / 2).astype(int))
        pockets = np_coords_to_points(np.array(pockets))

        # 4. shot calculation inputs, in projected coordinates
        stripes = np_coords_to_points(np.asarray(stripe_centers))
        solids = np_coords_to_points(np.asarray(solid_centers))
        black = np_coord_to_point(black_center)

        # draw pockets back in image coordinates
        for pocket in pockets:
            pocket_xy = unproject(point_to_np_coord(pocket), homography)
            cv2.circle(pool, _to_cv_point(pocket_xy), 20, (255, 0, 0), -1)

        if white_center is None:
            # No cue ball detected: nothing to aim from, and the cue -> ball
            # line below would have no coordinates to start at.
            shots = {}
        else:
            white = np_coord_to_point(white_center)
            # increase last arg to allow for higher angle shots
            if player == "solid":
                shots = find_direct_shots(white, black, solids, stripes, pockets, 50, 10)
            else:
                shots = find_direct_shots(white, black, stripes, solids, pockets, 50, 10)

        # 5./6. project each shot back to the player view and draw it:
        #       white -> target ball -> target pocket
        for pocket in shots:
            target_ball = _to_cv_point(
                unproject(point_to_np_coord(shots[pocket]), homography))
            target_pocket = _to_cv_point(
                unproject(point_to_np_coord(pocket), homography))
            cv2.line(pool, target_ball, target_pocket, (255, 255, 255), 2)  # ball -> hole
            cv2.line(pool, _to_cv_point((w_x, w_y)), target_ball,
                     (255, 255, 255), 2)  # cue/white -> ball

        cv2.imshow('pool frame', pool)
        if cv2.waitKey(1) == ord('q'):
            break
finally:
    # Release the capture and close windows even if a frame raises.
    cap.release()
    cv2.destroyAllWindows()