Source code for d3mdm.s3

# -*- coding: utf-8 -*-

import io
import logging
import os
import re
import tarfile

import boto3

LOGGER = logging.getLogger(__name__)


class S3Manager(object):
    """Load and store datasets as ``.tar.gz`` files in an S3 bucket."""

    def __init__(self, bucket, root_dir='datasets', skip_sublevels=False):
        self.bucket = bucket
        self.client = boto3.client('s3')
        self.root_dir = root_dir
        self.skip_sublevels = skip_sublevels

    def list_objects(self, prefix=''):
        """Return every key in the bucket that starts with ``prefix``."""
        resp = self.client.list_objects(Bucket=self.bucket, Prefix=prefix)
        keys = [c['Key'] for c in resp.get('Contents', [])]

        # Listings are paginated: keep requesting pages, using the last key
        # seen as the marker, until the response is no longer truncated.
        while resp['IsTruncated']:
            marker = keys[-1]
            resp = self.client.list_objects(Bucket=self.bucket, Prefix=prefix, Marker=marker)
            keys.extend([c['Key'] for c in resp.get('Contents', [])])

        return keys

    def load_tar(self, dataset_name, raw, tf):
        """Read the members of an open tarfile into a nested dict of bytes."""
        files = tf.getnames()
        if raw:
            # In raw mode keep only the dataset and problem folders; if
            # skip_sublevels is set, keep only the tables folder, the
            # datasetDoc.json and the problem folder.
            dataset = os.path.join(dataset_name, dataset_name + '_dataset')
            problem = os.path.join(dataset_name, dataset_name + '_problem')
            if self.skip_sublevels:
                prefixes = [dataset + '/tables/', dataset + '/datasetDoc.json', problem]
            else:
                prefixes = [dataset, problem]

            files = [fn for fn in files if any(fn.startswith(prefix) for prefix in prefixes)]

        root = dict()
        for key in files:
            LOGGER.debug("Getting file {} from tarfile".format(key))
            with tf.extractfile(key) as buf:
                content = buf.read()

            # Drop the top-level dataset directory and store the contents in a
            # dict tree that mirrors the remaining folder structure.
            path, filename = tuple(key.rsplit('/', 1))
            data = root
            for level in path.split('/')[1:]:
                data = data.setdefault(level, dict())

            data[filename] = content

        return root

    def load(self, dataset_name, raw=False):
        """Download ``<root_dir>/<dataset_name>.tar.gz`` and extract it in memory."""
        key = '{}/{}.tar.gz'.format(self.root_dir, dataset_name)
        LOGGER.info("Getting file {} from bucket {}".format(key, self.bucket))
        content = self.client.get_object(Bucket=self.bucket, Key=key)
        bytes_io = io.BytesIO(content['Body'].read())
        with tarfile.open(fileobj=bytes_io, mode='r:gz') as tf:
            return self.load_tar(dataset_name, raw, tf)

    def write(self, dataset, base_dir):
        """Pack ``dataset`` into a ``.tar.gz`` file and upload it to the bucket."""
        bytes_io = io.BytesIO()
        with tarfile.open(fileobj=bytes_io, mode='w:gz') as tf:
            self.write_tar(dataset, base_dir, tf)

        # The tarfile must be closed before getvalue() so that the gzip stream
        # is fully flushed into the buffer.
        key = '{}/{}.tar.gz'.format(self.root_dir, base_dir)
        LOGGER.info("Writing file {} into S3 bucket {}".format(key, self.bucket))
        self.client.put_object(Bucket=self.bucket, Key=key, Body=bytes_io.getvalue())

    def write_tar(self, dataset, base_dir, tf):
        """Recursively add the entries of a nested dict to an open tarfile."""
        for path, value in dataset.items():
            key = os.path.join(base_dir, path)
            if isinstance(value, dict):
                # Dicts represent directories: recurse with the extended path.
                self.write_tar(value, key, tf)
            else:
                LOGGER.debug("Adding file {} into tarfile".format(key))
                info = tarfile.TarInfo(name=key)
                info.size = len(value)
                bytes_io = io.BytesIO(value)
                tf.addfile(info, bytes_io)

    def datasets(self):
        resp = self.client.list_objects(Bucket=self.bucket, Prefix=self.root_dir)

        names = []
        regex = re.compile(r'{}/(.+)\.tar\.gz'.format(self.root_dir))
        for entry in resp.get('Contents', []):
            key = entry['Key']
            match = regex.match(key)
            if not match:
                print('WARNING: Invalid dataset name found in S3 bucket {}: {}'.format(
                    self.bucket, key))
            else:
                names.append(match.group(1))

        return names

    def exists(self, dataset_name):
        prefix = '{}/{}.tar.gz'.format(self.root_dir, dataset_name)
        return 'Contents' in self.client.list_objects(Bucket=self.bucket, Prefix=prefix)
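

# Illustrative usage sketch: 'my-d3m-bucket' and 'my_dataset' are placeholder
# names, and AWS credentials for boto3 are assumed to be configured.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    manager = S3Manager('my-d3m-bucket', root_dir='datasets')

    # List the dataset tarballs stored under datasets/ in the bucket.
    print(manager.datasets())

    if manager.exists('my_dataset'):
        # Load the tarball into a nested dict of folders down to filename -> bytes.
        dataset = manager.load('my_dataset')

        # Re-pack the same contents and upload them under a new name.
        manager.write(dataset, 'my_dataset_copy')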