259 changes: 201 additions & 58 deletions PyRDF/backend/AWS.py
@@ -4,11 +4,35 @@
import json
import logging
import time
import concurrent.futures
import os
import sys
import ROOT
from pathlib import Path

import boto3
import botocore
import cloudpickle as pickle

from PyRDF.backend.Dist import Dist
from .Dist import Dist

lambda_await_thread_stop = False


class FlushingLogger:
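"""Logger proxy that flushes every handler of the root logger after each info/warning/debug/error/critical call."""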
def __init__(self):
self.logger = logging.getLogger()

def __getattr__(self, name):
method = getattr(self.logger, name)
if name in ['info', 'warning', 'debug', 'error', 'critical', ]:
def flushed_method(msg, *args, **kwargs):
method(msg, *args, **kwargs)
for h in self.logger.handlers:
h.flush()
return flushed_method
else:
return method


class AWS(Dist):
@@ -25,7 +49,7 @@ def __init__(self, config={}):
more support will be added in the future.
"""
super(AWS, self).__init__(config)
self.logger = logging.getLogger()
self.logger = FlushingLogger() if logging.root.level >= logging.INFO else logging.getLogger()
self.npartitions = self._get_partitions()
self.region = config.get('region') or 'us-east-1'

@@ -35,27 +59,21 @@ def _get_partitions(self):
def ProcessAndMerge(self, mapper, reducer):
"""
Performs map-reduce using AWS Lambda.

Args:
mapper (function): A function that runs the computational graph
and returns a list of values.

reducer (function): A function that merges two lists that were
returned by the mapper.

Returns:
list: A list representing the values of action nodes returned
after computation (Map-Reduce).
"""

ranges = self.build_ranges()

def encode_object(object_to_encode) -> str:
return str(base64.b64encode(pickle.dumps(object_to_encode)))

# Make mapper and reducer transferable
pickled_mapper = encode_object(mapper)
pickled_reducer = encode_object(reducer)
pickled_mapper = AWS.encode_object(mapper)
pickled_reducer = AWS.encode_object(reducer)

# Setup AWS clients
s3_resource = boto3.resource('s3', region_name=self.region)
@@ -64,48 +82,49 @@ def encode_object(object_to_encode) -> str:
ssm_client = boto3.client('ssm', region_name=self.region)

# Check for existence of infrastructure
"""
s3_output_bucket = ssm_client.get_parameter(Name='output_bucket')['Parameter']['Value']
if not s3_output_bucket:
self.logger.info('AWS backend not initialized!')
return False

ssm_client.put_parameter(
Name='ranges_num',
Type='String',
Value=str(len(ranges)),
Overwrite=True
)

ssm_client.put_parameter(
Name='reducer',
Type='String',
Value=str(pickled_reducer),
Overwrite=True
)
"""

self.logger.info(f'Before lambdas invoke. Number of lambdas: {len(ranges)}')

def invoke_root_lambda(client, root_range, script):
payload = json.dumps({
'range': encode_object(root_range),
'script': script,
'start': str(root_range.start),
'end': str(root_range.end),
'filelist': str(root_range.filelist),
'friend_info': encode_object(root_range.friend_info)
})
return client.invoke(
FunctionName='root_lambda',
InvocationType='Event',
Payload=bytes(payload, encoding='utf8')
)
processing_bucket = ssm_client.get_parameter(Name='processing_bucket')['Parameter']['Value']

s3_resource.Bucket(processing_bucket).objects.all().delete()

invoke_begin = time.time()
# Invoke workers with ranges and mapper
call_results = []
for root_range in ranges:
call_result = invoke_root_lambda(lambda_client, root_range, pickled_mapper)
call_results.append(call_result)

wait_begin = time.time()
global lambda_await_thread_stop
lambda_await_thread_stop = False

with concurrent.futures.ThreadPoolExecutor(max_workers=len(ranges)) as executor:
executor.submit(AWS.wait_for_all_lambdas, s3_client, processing_bucket, len(ranges), self.logger)
futures = [executor.submit(AWS.invoke_root_lambda, root_range, pickled_mapper, self.region, self.logger)
for root_range in ranges]
call_results = [future.result() for future in futures]
if not all(call_results):
lambda_await_thread_stop = True

if lambda_await_thread_stop:
raise Exception('Some lambdas failed after multiple retries')

self.logger.info('All lambdas have been invoked')

# while True:
# results = s3.list_objects_v2(Bucket=s3_output_bucket, Prefix='out.pickle')
@@ -115,48 +134,172 @@ def invoke_root_lambda(client, root_range, script):
# time.sleep(1)
# result = s3.get_object(s3_output_bucket, 'out.pickle')

processing_bucket = ssm_client.get_parameter(Name='processing_bucket')['Parameter']['Value']
download_begin = time.time()

# Wait until all lambdas finished execution
while True:
results = s3_client.list_objects_v2(Bucket=processing_bucket)
if results['KeyCount'] == len(ranges):
break
self.logger.info(f'Lambdas finished: {results["KeyCount"]}')
time.sleep(1)
# Get names of output files, download and reduce them
filenames = AWS.get_all_objects_from_s3_bucket(s3_client, processing_bucket)
self.logger.info(f'Lambdas finished: {len(filenames)}')

tmp_files_directory = '/tmp'
AWS.remove_all_tmp_root_files(tmp_files_directory)

with concurrent.futures.ThreadPoolExecutor(max_workers=len(filenames)) as executor:
Review comment (PR author): this construction repeats multiple times; it could go into a separate generic function.
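A minimal sketch of such a helper, following the reviewer's suggestion (the name run_in_parallel and its placement as a static method are illustrative, not part of this PR):

@staticmethod
def run_in_parallel(func, args_list, max_workers=None):
    # Hypothetical generic helper: submit one task per argument tuple and
    # collect the results in submission order, mirroring the repeated executor blocks.
    workers = max_workers or max(len(args_list), 1)
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        futures = [executor.submit(func, *args) for args in args_list]
        return [future.result() for future in futures]

The download step could then read, for example: files = AWS.run_in_parallel(AWS.get_from_s3, [(f, self.region, processing_bucket, tmp_files_directory) for f in filenames]).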

futures = [executor.submit(AWS.get_from_s3, filename, self.region, processing_bucket, tmp_files_directory)
for filename in filenames]
files = [future.result() for future in futures]

reduce_begin = time.time()
# Get names of output files, download and reduce them
filenames = s3_client.list_objects_v2(Bucket=processing_bucket)['Contents']

# need better way to do that
accumulator = pickle.loads(s3_client.get_object(
Bucket=processing_bucket,
Key=filenames[0]['Key']
)['Body'].read())
to_process = files
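# Pairwise (tree) reduction: merge the downloaded results two at a time in parallel until a single object remains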
while len(to_process) > 1:
even_index_files = [to_process[i] for i in range(len(to_process)) if i % 2 == 0]
odd_index_files = [to_process[i] for i in range(len(to_process)) if i % 2 == 1]

for filename in filenames[1:]:
file = pickle.loads(s3_client.get_object(
Bucket=processing_bucket,
Key=filename['Key']
)['Body'].read())
accumulator = reducer(accumulator, file)
with concurrent.futures.ThreadPoolExecutor(len(to_process)) as executor:
futures = [executor.submit(reducer, pair[0], pair[1]) for pair in
zip(even_index_files, odd_index_files)]
to_process = [future.result() for future in futures]

if len(even_index_files) > len(odd_index_files):
to_process.append(even_index_files[-1])
elif len(even_index_files) < len(odd_index_files):
to_process.append(odd_index_files[-1])

reduction_result = to_process[0]

# Clean up intermediate objects after we're done
s3_resource.Bucket(processing_bucket).objects.all().delete()

bench = (
len(ranges),
wait_begin-invoke_begin,
reduce_begin-wait_begin,
time.time()-reduce_begin
download_begin - invoke_begin,
reduce_begin - download_begin,
time.time() - reduce_begin
)

print(bench)

return accumulator
# reduced_output = pickle.loads(result)
# return reduced_output
return reduction_result

def distribute_files(self, includes_list):
pass

@staticmethod
def encode_object(object_to_encode) -> str:
return str(base64.b64encode(pickle.dumps(object_to_encode)))

@staticmethod
def get_from_s3(filename, region, bucket_name, directory):
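# Download one worker's output file from S3, read every object out of the ROOT
# file, detach the objects from the file, and remove the temporary local copy.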
s3_client = boto3.client('s3', region_name=region)
local_filename = os.path.join(directory, filename['Key'])
s3_client.download_file(bucket_name, filename['Key'], local_filename)

tfile = ROOT.TFile(local_filename, 'OPEN')
result = []
# Get all objects from TFile
for key in tfile.GetListOfKeys():
result.append(key.ReadObj())
result[-1].SetDirectory(0)
tfile.Close()

# Remove temporary root file
Path(local_filename).unlink()

return result

@staticmethod
def get_all_objects_from_s3_bucket(s3_client, bucket_name):
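# List every object in the bucket, following S3 pagination via continuation tokens.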
response = s3_client.list_objects_v2(Bucket=bucket_name)
result = response.get('Contents', [])
while response.get('IsTruncated'):
cont_token = response.get('NextContinuationToken')
response = s3_client.list_objects_v2(Bucket=bucket_name, ContinuationToken=cont_token)
result += response.get('Contents', [])
return result

@staticmethod
def invoke_root_lambda(root_range, script, region, logger):
"""
Invoke the root_lambda function for a single range, retrying on failure.
Args:
root_range (Range): Range of data to process.
script (str): Base64-encoded, pickled mapper that processes the range.
region (str): Name of the AWS region.
logger (logging.Logger): Logger used to report retries and errors.
Returns:
bool: True if the lambda invocation and execution succeeded.
"""

trials = 3
client = boto3.client('lambda', region_name=region)

payload = json.dumps({
'range': AWS.encode_object(root_range),
'script': script,
'start': str(root_range.start),
'end': str(root_range.end),
'filelist': str(root_range.filelist),
'friend_info': AWS.encode_object(root_range.friend_info)
})

# Possibly report the number of invoked lambdas here for awsmonitor

while trials >= 0:
trials -= 1
try:
response = client.invoke(
FunctionName='root_lambda',
InvocationType='RequestResponse',
Payload=bytes(payload, encoding='utf8')
)

try:
response['Payload'] = json.loads(response['Payload'].read())
except Exception:
response['Payload'] = {}

if 'FunctionError' in response or response['Payload'].get('statusCode') == 500:
try:
# Get the error specification and strip the extra quotes added as a side effect of serialization
error_type = response['Payload']['errorType'][1:-1]
error_message = response['Payload']['errorMessage'][1:-1]
exception = getattr(sys.modules['builtins'], error_type)
msg = f"Lambda raised an exception: {error_message}"
except Exception:
exception = RuntimeError
msg = (f"Lambda raised an exception: (type={response['Payload']['errorType']},"
f"message={response['Payload']['errorMessage']})")
raise exception(msg)

except botocore.exceptions.ClientError as error:
# AWS-side errors
logger.warning(error.response['Error']['Message'])
except Exception as error:
# All other errors
logger.warning(str(error) + " (" + type(error).__name__ + ")")
else:
return True
time.sleep(1)

# Note: lambda finishes before s3 object is created
return False

@staticmethod
def remove_all_tmp_root_files(directory):
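# Delete any leftover .root files from previous runs in the given directory.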
for file in os.listdir(directory):
if file.endswith('.root'):
Path(os.path.join(directory, file)).unlink()

@staticmethod
def wait_for_all_lambdas(s3_client, processing_bucket, num_of_lambdas, logger):
# Wait until all lambdas finished execution
global lambda_await_thread_stop
while not lambda_await_thread_stop:
results = AWS.get_all_objects_from_s3_bucket(s3_client, processing_bucket)
logger.info(f'Lambdas finished: {len(results)}')
if len(results) == num_of_lambdas:
break
time.sleep(1)
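For context, a minimal client-side sketch of how this backend might be selected; the backend name 'AWS', the 'npartitions' key, and the dataset/file below are assumptions for illustration, not confirmed by this PR (only 'region' is read in AWS.__init__):

import PyRDF

# Assumptions: backend registered as 'AWS'; 'region' mirrors AWS.__init__,
# 'npartitions' the Dist base class; tree name and file path are placeholders.
PyRDF.use('AWS', {'region': 'us-east-1', 'npartitions': 16})

rdf = PyRDF.RDataFrame('Events', 'root://some.server//path/to/file.root')
hist = rdf.Define('pt2', 'pt*pt').Histo1D('pt2')
print(hist.GetEntries())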