Source code for atm.data

import glob
import logging
import os
import shutil

import boto3
import pandas as pd
import requests
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ClientError

LOGGER = logging.getLogger('atm')


def copy_files(extension, source, target=None):
    """Copy matching files from source to target.

    Scan the ``source`` folder and copy any file that ends with the given
    ``extension`` to the ``target`` folder. Both ``source`` and ``target``
    are expected to be either a ``str`` or a list or tuple of strings to be
    joined using ``os.path.join``. ``source`` will be interpreted as a path
    relative to the ``atm`` root code folder, and ``target`` will be
    interpreted as a path relative to the user's current working directory.

    If ``target`` is ``None``, ``source`` will be used, and if the ``target``
    directory does not exist, it will be created.

    Args:
        extension (str):
            File extension to copy.
        source (str or iterable):
            Source directory.
        target (str or iterable or None):
            Target directory. Defaults to ``None``.

    Returns:
        dict:
            Dictionary containing the file names without extension as keys
            and the new paths as values.
    """
    if isinstance(source, (list, tuple)):
        source = os.path.join(*source)

    if isinstance(target, (list, tuple)):
        target = os.path.join(*target)
    elif target is None:
        target = source

    source_dir = os.path.join(os.path.dirname(__file__), source)
    target_dir = os.path.join(os.getcwd(), target)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    file_paths = dict()
    for source_file in glob.glob(os.path.join(source_dir, '*.' + extension)):
        file_name = os.path.basename(source_file)
        target_file = os.path.join(target_dir, file_name)
        print('Generating file {}'.format(target_file))
        shutil.copy(source_file, target_file)

        # Strip the extension (and its dot) to build the dictionary key.
        file_paths[file_name[:-(len(extension) + 1)]] = target_file

    return file_paths
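

# A minimal usage sketch for ``copy_files``. The ``'json'`` extension and the
# ``['config', 'templates']`` source folder are assumptions for illustration,
# not folders guaranteed to exist inside ``atm``.
def _example_copy_files():
    copied = copy_files('json', ['config', 'templates'])
    # ``copied`` maps each file name without its extension to the path of the
    # freshly created copy inside the current working directory.
    for name, new_path in copied.items():
        print('{} -> {}'.format(name, new_path))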


def download_demo(datasets, path=None):
    """Download one or more demo datasets from the public ``atm-data`` S3 bucket.

    Args:
        datasets (str or list):
            Name or list of names of the datasets to download.
        path (str):
            Folder in which to store the datasets. Defaults to a ``demos``
            folder inside the current working directory.

    Returns:
        str or list:
            Path to the downloaded dataset, or list of paths if more than
            one dataset was downloaded.
    """
    if not isinstance(datasets, list):
        datasets = [datasets]

    if path is None:
        path = os.path.join(os.getcwd(), 'demos')

    if not os.path.exists(path):
        os.makedirs(path)

    # Anonymous (unsigned) access is enough for the public demo bucket.
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))

    paths = list()
    for dataset in datasets:
        save_path = os.path.join(path, dataset)
        try:
            LOGGER.info('Downloading {}'.format(dataset))
            client.download_file('atm-data', dataset, save_path)
            paths.append(save_path)
        except ClientError as e:
            LOGGER.error('An error occurred trying to download from AWS S3. '
                         'The following error has been returned: {}'.format(e))

    return paths[0] if len(paths) == 1 else paths
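

# Usage sketch for ``download_demo``. The dataset name ``'pollution_1.csv'``
# is an assumption for illustration; use ``get_demos`` below to list what is
# actually available in the bucket.
def _example_download_demo():
    # A single name returns a single path; a list of names returns a list.
    local = download_demo('pollution_1.csv')
    print(local)  # e.g. /current/dir/demos/pollution_1.csv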


def get_demos(args=None):
    """Return the names of the datasets available in the ``atm-data`` S3 bucket."""
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    available_datasets = [
        obj['Key']
        for obj in client.list_objects(Bucket='atm-data')['Contents']
    ]
    return available_datasets
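

# Sketch combining the two helpers above: list the available demo datasets
# and download the first one. Purely illustrative.
def _example_get_demos():
    available = get_demos()
    print(available)
    if available:
        download_demo(available[0])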


def _download_from_s3(path, local_path, aws_access_key=None, aws_secret_key=None, **kwargs):
    client = boto3.client(
        's3',
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
    )

    # The bucket name is the first path segment after the ``s3://`` prefix;
    # everything that follows it is the object key.
    bucket = path.split('/')[2]
    file_to_download = path.replace('s3://{}/'.format(bucket), '')

    try:
        LOGGER.info('Downloading {}'.format(path))
        client.download_file(bucket, file_to_download, local_path)
        return local_path
    except ClientError as e:
        LOGGER.error('An error occurred trying to download from AWS S3. '
                     'The following error has been returned: {}'.format(e))


def _download_from_url(url, local_path, **kwargs):
    data = requests.get(url).text
    with open(local_path, 'wb') as outfile:
        outfile.write(data.encode())

    LOGGER.info('File saved at {}'.format(local_path))
    return local_path


DOWNLOADERS = {
    's3': _download_from_s3,
    'http': _download_from_url,
    'https': _download_from_url,
}


def _download(path, local_path, **kwargs):
    # Dispatch to the right downloader based on the URI scheme.
    protocol = path.split(':', 1)[0]
    downloader = DOWNLOADERS.get(protocol)
    if not downloader:
        raise ValueError('Unknown protocol: {}'.format(protocol))

    return downloader(path, local_path, **kwargs)


def _get_local_path(name, path, aws_access_key=None, aws_secret_key=None):
    # If ``path`` already points at a local file, use it directly.
    if os.path.isfile(path):
        return path

    data_path = os.path.join(os.getcwd(), 'data')
    if not name.endswith('csv'):
        name = name + '.csv'

    # Reuse a previously downloaded copy if one exists.
    local_path = os.path.join(data_path, name)
    if os.path.isfile(local_path):
        return local_path

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    _download(path, local_path, aws_access_key=aws_access_key, aws_secret_key=aws_secret_key)

    return local_path
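

# Sketch of the dispatch logic above: ``_download`` selects a downloader from
# ``DOWNLOADERS`` by URI scheme, so unsupported schemes fail fast. The
# ``ftp://`` URL below is an assumption used only to show the error branch.
def _example_download_dispatch():
    try:
        _download('ftp://example.com/data.csv', '/tmp/data.csv')
    except ValueError as e:
        print(e)  # Unknown protocol: ftp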


def load_data(name, path, aws_access_key=None, aws_secret_key=None):
    """Load data from the given path.

    If the path is a URL or an S3 path, download it and make a local copy
    to avoid having to download it again later.

    Args:
        name (str):
            Name of the dataset. Used to cache the data locally.
        path (str):
            Local path, S3 path or URL.
        aws_access_key (str):
            AWS access key. Optional.
        aws_secret_key (str):
            AWS secret key. Optional.

    Returns:
        pandas.DataFrame:
            The loaded data.
    """
    local_path = _get_local_path(
        name, path, aws_access_key=aws_access_key, aws_secret_key=aws_secret_key)

    return pd.read_csv(local_path).dropna(how='any')
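

# End-to-end sketch for ``load_data``: the first call downloads the CSV and
# caches it under ``./data/airlines.csv``; later calls read the cached copy.
# The dataset name and URL are assumptions for illustration.
def _example_load_data():
    df = load_data('airlines', 'https://example.com/airlines.csv')
    print(df.shape)  # rows with any missing value have been dropped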