diff --git a/src/main/python/feature-extractor/server.py b/src/main/python/feature-extractor/server.py index f407bb9b..3523b460 100644 --- a/src/main/python/feature-extractor/server.py +++ b/src/main/python/feature-extractor/server.py @@ -3,9 +3,11 @@ sys.path.append('./pb') from concurrent import futures +from multiprocessing import Pool import argparse import logging import os +import signal import time import grpc @@ -14,93 +16,133 @@ import pb.service_pb2 as service_pb2 import pb.service_pb2_grpc as service_pb2_grpc +# isn't used but re-exported so it can be used in tests +from pb.service_pb2 import gopkg_dot_in_dot_bblfsh_dot_sdk_dot_v1_dot_uast_dot_generated__pb2 as uast_pb _ONE_DAY_IN_SECONDS = 60 * 60 * 24 +# keep extractors out of the Service class to be able to pickle them +# return list instead of iterator for pickle also + + +def _identifiers_extractor(uast, options): + return list( + IdentifiersBagExtractor( + docfreq_threshold=options.docfreqThreshold, + split_stem=options.splitStem, + weight=options.weight or 1).extract(uast)) + + +def _literals_extractor(uast, options): + return list( + LiteralsBagExtractor( + docfreq_threshold=options.docfreqThreshold, + weight=options.weight or 1).extract(uast)) + + +def _uast2seq_extractor(uast, options): + seq_len = list(options.seqLen) if options.seqLen else None + + return list( + UastSeqBagExtractor( + docfreq_threshold=options.docfreqThreshold, + weight=options.weight or 1, + stride=options.stride or 1, + seq_len=seq_len or 5).extract(uast)) + + +def _graphlet_extractor(uast, options): + return list( + GraphletBagExtractor( + docfreq_threshold=options.docfreqThreshold, + weight=options.weight or 1).extract(uast)) + + +def _features_from_iter(f_iter): + return [service_pb2.Feature(name=f[0], weight=f[1]) for f in f_iter] + class Service(service_pb2_grpc.FeatureExtractorServicer): """Feature Extractor Service""" - extractors_names = ["identifiers", "literals", "uast2seq", "graphlet"] + pool = None + extractors 
= { + "identifiers": _identifiers_extractor, + "literals": _literals_extractor, + "uast2seq": _uast2seq_extractor, + "graphlet": _graphlet_extractor, + } + + def __init__(self, pool): + super(Service, self).__init__() + self.pool = pool def Extract(self, request, context): """ Extract features using multiple extrators """ - extractors = [] + results = [] - for name in self.extractors_names: + for name in self.extractors: if request.HasField(name): options = getattr(request, name, None) if options is None: continue - constructor = getattr(self, "_%s_extractor" % name) - extractors.append(constructor(options)) + + result = self.pool.apply_async(self.extractors[name], + (request.uast, options)) + results.append(result) features = [] - for ex in extractors: - features.extend(_features_iter_to_list(ex.extract(request.uast))) + for result in results: + features.extend(_features_from_iter(result.get())) return service_pb2.FeaturesReply(features=features) def Identifiers(self, request, context): """Extract identifiers weighted set""" - it = self._identifiers_extractor(request.options).extract(request.uast) - return service_pb2.FeaturesReply(features=_features_iter_to_list(it)) + it = self.pool.apply(_identifiers_extractor, + (request.uast, request.options)) + return service_pb2.FeaturesReply(features=_features_from_iter(it)) def Literals(self, request, context): """Extract literals weighted set""" - it = self._literals_extractor(request.options).extract(request.uast) - return service_pb2.FeaturesReply(features=_features_iter_to_list(it)) + it = self.pool.apply(_literals_extractor, + (request.uast, request.options)) + return service_pb2.FeaturesReply(features=_features_from_iter(it)) def Uast2seq(self, request, context): """Extract uast2seq weighted set""" - it = self._uast2seq_extractor(request.options).extract(request.uast) - return service_pb2.FeaturesReply(features=_features_iter_to_list(it)) + it = self.pool.apply(_uast2seq_extractor, + (request.uast, request.options)) 
+ return service_pb2.FeaturesReply(features=_features_from_iter(it)) def Graphlet(self, request, context): """Extract graphlet weighted set""" - it = self._graphlet_extractor(request.options).extract(request.uast) - return service_pb2.FeaturesReply(features=_features_iter_to_list(it)) - - def _identifiers_extractor(self, options): - return IdentifiersBagExtractor( - docfreq_threshold=options.docfreqThreshold, - split_stem=options.splitStem, - weight=options.weight or 1) - - def _literals_extractor(self, options): - return LiteralsBagExtractor( - docfreq_threshold=options.docfreqThreshold, - weight=options.weight or 1) - - def _uast2seq_extractor(self, options): - seq_len = list(options.seqLen) if options.seqLen else None + it = self.pool.apply(_graphlet_extractor, + (request.uast, request.options)) + return service_pb2.FeaturesReply(features=_features_from_iter(it)) - return UastSeqBagExtractor( - docfreq_threshold=options.docfreqThreshold, - weight=options.weight or 1, - stride=options.stride or 1, - seq_len=seq_len or 5) - def _graphlet_extractor(self, options): - return GraphletBagExtractor( - docfreq_threshold=options.docfreqThreshold, - weight=options.weight or 1) - - -def _features_iter_to_list(f_iter): - return [service_pb2.Feature(name=f[0], weight=f[1]) for f in f_iter] +def worker_init(): + """ ignore SIGINT (Ctrl-C) event inside workers. 
+ Read more here: + https://stackoverflow.com/questions/1408356/keyboard-interrupts-with-pythons-multiprocessing-pool + """ + signal.signal(signal.SIGINT, signal.SIG_IGN) def serve(port, workers): logger = logging.getLogger('feature-extractor') - server = _get_server(port, workers) + # processes=None uses os.cpu_count() as a value + pool = Pool(processes=None, initializer=worker_init) + + server = _get_server(port, workers, pool) server.start() logger.info("server started on port %d" % port) @@ -110,19 +152,21 @@ def serve(port, workers): while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: + pool.terminate() server.stop(0) -def _get_server(port, workers): +def _get_server(port, workers, pool): server = grpc.server(futures.ThreadPoolExecutor(max_workers=workers)) - service_pb2_grpc.add_FeatureExtractorServicer_to_server(Service(), server) + service_pb2_grpc.add_FeatureExtractorServicer_to_server( + Service(pool), server) server.add_insecure_port('[::]:%d' % port) return server if __name__ == '__main__': port = int(os.getenv('FEATURE_EXT_PORT', "9001")) - workers = int(os.getenv('FEATURE_EXT_WORKERS', "10")) + workers = int(os.getenv('FEATURE_EXT_WORKERS', "100")) parser = argparse.ArgumentParser(description='Feature Extractor Service.') parser.add_argument( diff --git a/src/main/python/feature-extractor/test_server.py b/src/main/python/feature-extractor/test_server.py index da37ae38..3deae407 100644 --- a/src/main/python/feature-extractor/test_server.py +++ b/src/main/python/feature-extractor/test_server.py @@ -2,15 +2,16 @@ import sys sys.path.append('./pb') +from multiprocessing import Pool import grpc import json import unittest -import pb.service_pb2 as service_pb2 -import pb.service_pb2_grpc as service_pb2_grpc from google.protobuf.json_format import ParseDict as ProtoParseDict -from pb.service_pb2 import gopkg_dot_in_dot_bblfsh_dot_sdk_dot_v1_dot_uast_dot_generated__pb2 as uast_pb -from server import _get_server +# all grpc stuff must be 
imported from server and not directly from pb package +# otherwise requests will fail with +# PicklingError: Can't pickle <function ...>: it's not the same object as ... +from server import _get_server, service_pb2, service_pb2_grpc, uast_pb class TestServer(unittest.TestCase): @@ -24,8 +25,9 @@ def setUp(self): node.ParseFromString(f.read()) self.uast = node + pool = Pool(processes=1) port = get_open_port() - self.server = _get_server(port, 1) + self.server = _get_server(port, 1, pool) self.server.start() channel = grpc.insecure_channel("localhost:%d" % port)