# Copyright (c) 2018 Stephen Bunn (stephen@bunn.io)
# MIT License <https://opensource.org/licenses/MIT>
import os
import abc
import enum
import time
import uuid
import shutil
import itertools
from typing import (Any, List, Tuple, Callable,)
from tempfile import (TemporaryDirectory,)
from concurrent.futures import (ThreadPoolExecutor,)
from .. import (__version__,)
from ..content import (Content,)
import blinker
class DownloadState(enum.Enum):
    """ Enumerates the states a download request can be in.

    Members:
        - ``STOPPED``: the download halted (an error occurred)
        - ``RUNNING``: the download is actively transferring data
        - ``PREPARING``: the download is starting up
        - ``FINISHED``: the download completed successfully
    """

    STOPPED = 'stopped'
    RUNNING = 'running'
    PREPARING = 'preparing'
    FINISHED = 'finished'
class BaseDownloader(abc.ABC):
    """ The base abstract base downloader.

    `All downloaders must extend from this class.`
    """

    # class-level signal; emitted as send(download_id, current=..., total=...)
    # NOTE(review): being class-level, hooks connected here are shared by all
    # downloader instances for the lifetime of the process
    on_progress = blinker.Signal()

    @property
    def download_state(self) -> dict:
        """ dict[str,DownloadState]: The download state dictionary.
        """
        # lazily created so subclasses need not call a base __init__
        if not hasattr(self, '_download_state'):
            self._download_state = {}
        return self._download_state

    @property
    def progress_store(self) -> dict:
        """ dict[str,int]: The downloaded content size for progress reporting.
        """
        if not hasattr(self, '_progress_store'):
            self._progress_store = {}
        return self._progress_store

    # ``abc.abstractclassmethod`` is deprecated since Python 3.3; stacking
    # ``classmethod`` over ``abstractmethod`` is the supported equivalent
    @classmethod
    @abc.abstractmethod
    def can_handle(cls, content: Content) -> bool:
        """ Determines if this downloader can handle the given content.

        Args:
            content (Content): The content instance to check.

        Returns:
            bool: True if this downloader can handle the content.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def handle_download(self, source: str, url: str, to_path: str) -> str:
        """ Downloads a single fragment to a local path.

        NOTE(review): ``download()`` invokes this as
        ``handle_download(download_id, fragment, path, max_connections=...)``
        which does not match this signature — confirm against concrete
        subclass implementations.

        Args:
            source (str): The unique id of the download request.
            url (str): The url of the fragment to download.
            to_path (str): The local path to save the fragment to.

        Returns:
            str: The local path the fragment was saved to.
        """
        raise NotImplementedError()

    def _calc_ranges(
        self, content_length: int, max_connections: int
    ) -> List[Tuple[int, int]]:
        """ Calculates byte ranges given a content size and the number of \
            allowed connections.

        Args:
            content_length (int): The total size of the content to download.
            max_connections (int): The maximum allowed connections to use.

        Returns:
            list[tuple[int, int]]: A list of at most ``max_connections``
                ``(start, end)`` byte-range tuples covering
                ``[0, content_length)``.
        """
        # ceil division guarantees a non-zero step (the previous floor
        # division produced a 0 step — and a ValueError from range() — when
        # content_length < max_connections) and guarantees at most
        # ``max_connections`` ranges without any post-merge fixup
        step = max(1, -(-content_length // max_connections))
        boundaries = list(range(0, content_length, step)) + [content_length]
        return list(zip(boundaries, boundaries[1:]))

    def handle_progress(
        self, download_id: str, content_length: int,
        update_delay: float=0.1
    ):
        """ The progress reporting handler.

        Emits ``on_progress`` periodically until the download either reaches
        ``content_length`` bytes or is marked ``STOPPED``.

        Args:
            download_id (str): The unique id of the download request.
            content_length (int): The total size of the downloading content.
            update_delay (float, optional): The frequency (in seconds) which \
                progress updates are emitted.
        """
        try:
            # setup sync values if they don't exist (race-condition fix)
            self.download_state.setdefault(download_id, DownloadState.PREPARING)
            self.progress_store.setdefault(download_id, 0)
            while True:
                current = self.progress_store[download_id]
                # check exit conditions FIRST: previously the STOPPED check
                # lived in an ``elif`` that was unreachable while progress was
                # below content_length, so stopped downloads spun forever
                if current >= content_length or \
                        self.download_state[download_id] == \
                        DownloadState.STOPPED:
                    break
                self.on_progress.send(
                    download_id,
                    current=current,
                    total=content_length
                )
                time.sleep(update_delay)
        finally:
            # pop (not del) so a missing entry can never raise from finally
            self.progress_store.pop(download_id, None)
            # emit a final 100% event; NOTE(review): also sent on STOPPED,
            # mirroring the original behavior — confirm listeners expect this
            self.on_progress.send(
                download_id,
                current=content_length,
                total=content_length
            )

    def download(
        self, content: Content, to_path: str,
        max_fragments: int=1, max_connections: int=8,
        progress_hook: Callable[[Any], None]=None, update_delay: float=0.1,
    ) -> str:
        """ The simplified download method.

        Note:
            The ``max_fragments`` and ``max_connections`` rules imply that
            potentially ``(max_fragments * max_connections)`` connections
            from the local system's IP can exist at any time.
            Many hosts will flag/ban IPs which utilize more than 10
            connections for a single resource.
            **For this reason**, ``max_fragments`` and ``max_connections`` are
            set to 1 and 8 respectively by default.

        Args:
            content (Content): The content instance to download.
            to_path (str): The path to save the resulting download to.
            max_fragments (int, optional): The number of fragments to process
                in parallel.
            max_connections (int, optional): The number of connections to
                allow for downloading a single fragment.
            progress_hook (callable, optional): A progress hook that accepts
                the arguments ``(download_id, current_size, total_size)`` for
                progress updates.
            update_delay (float, optional): The frequency (in seconds) where
                progress updates are sent to the given ``progress_hook``.

        Returns:
            str: The downloaded file's local path.

        Examples:
            Basic usage where ``$HOME`` is the home directory of the
            currently executing user.

            >>> import os
            >>> from qetch.extractors import (GfycatExtractor,)
            >>> from qetch.downloaders import (HTTPDownloader,)
            >>> content = next(GfycatExtractor().extract(GFYCAT_URL))[0]
            >>> saved_to = HTTPDownloader().download(
            ...     content,
            ...     os.path.expanduser('~/Downloads/saved_content.mp4'))
            >>> print(saved_to)
            $HOME/Downloads/saved_content.mp4

            Similar basic usage, but with a given progress hook sent updates
            every 0.1 seconds.

            >>> def progress(download_id, current, total):
            ...     print(f'{((current / total) * 100.0):6.2f}')
            >>> saved_to = HTTPDownloader().download(
            ...     content,
            ...     os.path.expanduser('~/Downloads/saved_content.mp4'),
            ...     progress_hook=progress,
            ...     update_delay=0.1)
            0.00
            0.00
            23.01
            54.32
            73.09
            90.49
            97.12
            100.00
            >>> print(saved_to)
            $HOME/Downloads/saved_content.mp4
        """
        assert (max_fragments > 0), (
            f"'max_fragments' must be at least 1, received {max_fragments!r}"
        )
        assert (max_connections > 0), (
            f"'max_connections' must be at least 1, received "
            f"{max_connections!r}"
        )
        # generate unique download id for state & progress syncing
        download_id = str(uuid.uuid4())
        with TemporaryDirectory(
            prefix=f'{__version__.__name__}[{download_id}]-',
        ) as temporary_dir:
            # +1 worker is for progress handler
            with ThreadPoolExecutor(
                max_workers=(max_fragments + 1)
            ) as executor:
                if callable(progress_hook):
                    # NOTE(review): blinker connects with a weak reference by
                    # default — confirm callers keep progress_hook alive
                    self.on_progress.connect(progress_hook)
                    executor.submit(
                        self.handle_progress,
                        download_id, content.get_size(),
                        update_delay=update_delay
                    )
                download_futures = [
                    executor.submit(
                        self.handle_download,
                        download_id, fragment,
                        os.path.join(temporary_dir, str(fragment_idx)),
                        max_connections=max_connections
                    )
                    for (fragment_idx, fragment,) in
                    enumerate(content.fragments)
                ]
                # FIXME: handle KeyboardInterrupt with parent thread correctly
                try:
                    # block until every fragment completes; result() re-raises
                    # worker exceptions so a failed fragment marks the state
                    # STOPPED instead of silently reaching FINISHED (the old
                    # ``while all(future.running())`` poll exited as soon as
                    # any single future was pending or done)
                    for future in download_futures:
                        future.result()
                    self.download_state[download_id] = DownloadState.FINISHED
                except Exception:
                    self.download_state[download_id] = DownloadState.STOPPED
                    # bare raise preserves the original traceback
                    raise
            # apply content extractors merge and move result (one step)
            shutil.move(
                content.extractor.merge([
                    future.result()
                    for future in download_futures
                ]),
                to_path
            )
            return to_path