-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrecognize_command.py
More file actions
175 lines (136 loc) · 4.64 KB
/
recognize_command.py
File metadata and controls
175 lines (136 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import time
import math
import struct
import wave
import sys
# TODO: Refactor this spaghetti
PATH = os.path.dirname(os.path.realpath(__file__))
sys.path.append('..')
sys.path.append(PATH)
import pyaudio
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from model import SpeechRecognitionModel
import data_processing as dp
# Logger Import
from local_server.logger import Logger
# Audio configs vars
from audio_processing.audio_confs import *
# Scale factor applied to each signed 16-bit sample before squaring in
# Listener.rms (1 / 2**15).
SHORT_NORMALIZE = (1.0/32768.0)
# Seconds added to time.time() to compute the recording deadline in
# Listener.record.
TIMEOUT_LENGTH = 2
# Resource locations relative to this file's directory. Built with
# os.path.join so the separator matches the host OS instead of assuming
# Windows-style backslashes.
MODEL_PATH = os.path.join(PATH, 'command_model_trained.pth')
CSV_PATH = os.path.join(PATH, 'data', 'command_labels.csv')
WAV_PATH = os.path.join(PATH, 'temp.wav')
# Logger setup. Connecting is best-effort: if connect() raises, the module
# still loads and _print_log falls back to writing to stdout only.
logger = Logger()
try:
    logger.connect()
except Exception:
    logger_connected = False
else:
    logger_connected = True
def _print_log(data):
    """Print *data* to stdout, flush, and forward it to the logger.

    The message is forwarded with a ``RECO: `` prefix, and only when the
    module-level ``logger_connected`` flag is true.
    """
    print(data)
    # Flush immediately so the line is not held back in the stdout buffer.
    sys.stdout.flush()
    if not logger_connected:
        return
    message = f'RECO: {data}'
    logger.log(message)
class Listener:
    """Capture one utterance from a PyAudio stream.

    Constructing a Listener opens an audio stream. Call ``close()`` (or use
    the instance as a context manager) to release the stream and the
    PyAudio instance once recording is finished.
    """

    @staticmethod
    def rms(frame):
        """Return the root-mean-square level of *frame*, scaled by 1000.

        Args:
            frame: Raw bytes that are unpacked as signed 16-bit samples
                (struct code ``h``).

        Returns:
            The RMS of the samples (each multiplied by SHORT_NORMALIZE),
            times 1000. An empty frame yields 0.0.

        Raises:
            struct.error: If the frame length does not match the sample
                count derived from S_WIDTH.
        """
        if not frame:
            # No samples: avoid dividing by a zero sample count.
            return 0.0
        # Integer division, since the count feeds the struct format string.
        count = len(frame) // S_WIDTH
        samples = struct.unpack("%dh" % count, frame)
        sum_squares = 0.0
        for sample in samples:
            n = sample * SHORT_NORMALIZE
            sum_squares += n * n
        return math.sqrt(sum_squares / count) * 1000

    def __init__(self):
        """Create a PyAudio instance and open the audio stream."""
        self._closed = False
        self.p = pyaudio.PyAudio()
        self.stream = self.p.open(format=FORMAT,
                                  channels=CHANNELS,
                                  rate=S_RATE,
                                  input=True,
                                  output=True,
                                  frames_per_buffer=CHUNK)

    def __enter__(self):
        """Return the listener itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release audio resources on leaving the ``with`` block."""
        self.close()

    def close(self):
        """Stop and close the stream, then terminate PyAudio.

        Safe to call more than once; later calls do nothing.
        """
        if self._closed:
            return
        self._closed = True
        try:
            self.stream.stop_stream()
            self.stream.close()
        finally:
            # Always release PyAudio, even if stopping the stream failed.
            self.p.terminate()

    def record(self):
        """Record until the input has been below THRESHOLD for a while.

        Every chunk whose RMS reaches THRESHOLD pushes the deadline to
        TIMEOUT_LENGTH seconds from that moment; recording stops once the
        deadline passes.

        Returns:
            The recorded chunks joined into a single bytes object.
        """
        _print_log('\nAudio detected, recording now...\n')
        chunks = []
        deadline = time.time() + TIMEOUT_LENGTH
        while time.time() <= deadline:
            data = self.stream.read(CHUNK)
            if self.rms(data) >= THRESHOLD:
                deadline = time.time() + TIMEOUT_LENGTH
            chunks.append(data)
        return b''.join(chunks)

    def listen(self):
        """Block until a chunk exceeds THRESHOLD, then record an utterance.

        Returns:
            The bytes returned by ``record()``. The chunk that triggered
            recording is not included.
        """
        _print_log('\nNow Listening...\n')
        while True:
            chunk = self.stream.read(CHUNK)
            if self.rms(chunk) > THRESHOLD:
                return self.record()

    def write(self, filename, recording):
        """Write *recording* to *filename* as a WAV file.

        Args:
            filename: Destination path of the WAV file.
            recording: Raw frame bytes, e.g. the result of ``record()``.
        """
        # The context manager closes the file even if a write fails.
        with wave.open(filename, 'wb') as wf:
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(self.p.get_sample_size(FORMAT))
            wf.setframerate(S_RATE)
            wf.writeframes(recording)
class CommandRecognizer():
    """Recognise a spoken command by matching model output to known labels."""

    def __init__(self):
        """Load the trained model weights and the list of command labels."""
        # Initialising a new model to recognize
        self.model = SpeechRecognitionModel()
        # Loading a pre-trained model. map_location lets a checkpoint that
        # was saved from a GPU load on a CPU-only host.
        state = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
        self.model.load_state_dict(state)
        self.model.eval()
        # Loading command labels (second column of the CSV)
        self.commands = list(pd.read_csv(CSV_PATH).iloc[:, 1])

    @staticmethod
    def _best_match(scores):
        """Return ``(command, cer)`` for the lowest CER in *scores*.

        When several commands share the lowest CER, the first one in
        iteration order wins, so the result is always a single label.
        """
        command = min(scores, key=scores.get)
        return command, scores[command]

    @staticmethod
    def _verdict(min_cer):
        """Return the log header for *min_cer*, or None if it is too high.

        Below 0.5 counts as recognised and below 0.7 as partially
        recognised. The two ranges are contiguous, so a CER of exactly 0.5
        is graded as a partial match.
        """
        if min_cer < 0.5:
            return 'Command Recognised -'
        if min_cer < 0.7:
            return 'Command Partially Recognised -'
        return None

    def recognize(self):
        """Listen for one utterance and match it to a known command.

        Returns:
            ``(command, accuracy)`` where accuracy is ``1 - CER`` of the
            best match, or ``(None, accuracy)`` when the best CER is 0.7
            or higher.
        """
        # Listening and writing to a temp wav file
        listener = Listener()
        try:
            wav_data = listener.listen()
            start_t = time.time()
            listener.write(WAV_PATH, wav_data)
        finally:
            # A new Listener is opened on every call, so release its audio
            # stream and PyAudio instance here to avoid leaking them.
            try:
                listener.stream.stop_stream()
                listener.stream.close()
            finally:
                listener.p.terminate()

        # Reading wav file as a tensor
        # TODO: Read directly from the recorded frames
        # NOTE(review): load_wav is kept as-is; switching loaders could
        # change the sample scaling the model sees, so confirm first.
        waveform, _ = torchaudio.load_wav(WAV_PATH)

        # Data-preprocessing
        specs = dp.preprocess_reco(waveform)

        # Recognising. Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            out = F.log_softmax(self.model(specs), dim=2)

        # Data-postprocessing
        pred_str = ''.join(dp.postprocess_reco(out))
        scores = {command: dp.cer(command, pred_str)
                  for command in self.commands}
        result, min_cer = self._best_match(scores)

        # Results
        accuracy = 1 - min_cer
        verdict = self._verdict(min_cer)
        if verdict is None:
            _print_log('Not able to recognize the command. Try again.')
            return None, accuracy

        _print_log(verdict)
        _print_log(result)
        _print_log('Accuracy Score - {:0.4f}\n'.format(accuracy))
        _print_log("\n--- Total Recognition Time: {:0.4f} seconds ---\n".format(time.time() - start_t))
        return result, accuracy
if __name__ == "__main__":
    # Build the recognizer once (this loads the model weights and the
    # command labels), then handle utterances one after another until the
    # process is interrupted.
    command_recognizer = CommandRecognizer()
    while True:
        command_recognizer.recognize()