Parallel reading from stdin and uploading to Pashka's servers (using asyncio)

This commit is contained in:
Egor Aristov 2020-06-07 09:45:27 +03:00
parent ce8489af1a
commit 5dca2487cd

View File

@ -1,3 +1,4 @@
import asyncio
import datetime import datetime
import hashlib import hashlib
import math import math
@ -7,10 +8,12 @@ import textwrap
import time import time
import typing import typing
from collections import namedtuple from collections import namedtuple
from concurrent.futures.thread import ThreadPoolExecutor
from functools import wraps
import click import click
from telethon.helpers import generate_random_long from telethon.helpers import generate_random_long
from telethon.sync import TelegramClient, connection from telethon import TelegramClient, connection
from telethon.tl.custom.message import Message from telethon.tl.custom.message import Message
from telethon.tl.functions.upload import SaveBigFilePartRequest from telethon.tl.functions.upload import SaveBigFilePartRequest
from telethon.tl.functions.upload import GetFileRequest from telethon.tl.functions.upload import GetFileRequest
@ -36,6 +39,14 @@ class Config:
pass_config = click.make_pass_decorator(Config, ensure=True) pass_config = click.make_pass_decorator(Config, ensure=True)
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
currentLoop = asyncio.get_event_loop()
return currentLoop.run_until_complete(f(*args, **kwargs))
return wrapper
def parse_file_size(humanString: str): def parse_file_size(humanString: str):
humanString = humanString.strip() humanString = humanString.strip()
units = {"k": pow(2, 10), "m": pow(2, 20), "g": pow(2, 30)} units = {"k": pow(2, 10), "m": pow(2, 20), "g": pow(2, 30)}
@ -99,13 +110,13 @@ def retrieve_app_hash(app_config):
return app_hash, app_id, proxy return app_hash, app_id, proxy
def check_logged_in(config: Config): async def check_logged_in(config: Config):
config.client.connect() await config.client.connect()
if not config.client.is_user_authorized(): if not await config.client.is_user_authorized():
raise click.ClickException('You are not authorized. Please log in first') raise click.ClickException('You are not authorized. Please log in first')
def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamSize: int, nameHash: str, dialog: str, maxFileSize): async def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamSize: int, nameHash: str, dialog: str, maxFileSize):
fileId = generate_random_long() fileId = generate_random_long()
chunkSize = pow(2, 19) chunkSize = pow(2, 19)
maxChunksInFile = maxFileSize / chunkSize maxChunksInFile = maxFileSize / chunkSize
@ -119,8 +130,18 @@ def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamS
currentFileBytesWritten = 0 currentFileBytesWritten = 0
realBytesWritten = 0 realBytesWritten = 0
nextBuffer = source.read(chunkSize)
if not len(nextBuffer):
raise click.ClickException('Input stream is empty')
currentLoop = asyncio.get_running_loop()
currentExecutor = ThreadPoolExecutor(max_workers=1)
def readNextPortion(readBytes=chunkSize):
return source.read(readBytes)
lastRealTimeMeasurement = time.time()
while True: while True:
buffer = source.read(chunkSize) buffer = nextBuffer
bufLen = len(buffer) bufLen = len(buffer)
if not bufLen: if not bufLen:
break break
@ -128,14 +149,19 @@ def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamS
raise click.ClickException('Stream is larger than expected file size') raise click.ClickException('Stream is larger than expected file size')
if bufLen < chunkSize: if bufLen < chunkSize:
buffer = buffer + bytearray([fillByte] * (chunkSize - bufLen)) buffer = buffer + bytearray([fillByte] * (chunkSize - bufLen))
client(SaveBigFilePartRequest(
taskSend = client(SaveBigFilePartRequest(
fileId, currentFileChunkPos, currentFileTotalChunks, buffer fileId, currentFileChunkPos, currentFileTotalChunks, buffer
)) ))
taskGetNextBuffer = currentLoop.run_in_executor(currentExecutor, readNextPortion, chunkSize)
clientResponse, nextBuffer = await asyncio.gather(taskSend, taskGetNextBuffer, loop=currentLoop, return_exceptions=False)
currentFileChunkPos += 1 currentFileChunkPos += 1
currentFileBytesWritten += bufLen currentFileBytesWritten += bufLen
realBytesWritten += bufLen realBytesWritten += bufLen
if currentFileChunkPos % 10 == 0: if currentFileChunkPos % 10 == 0:
click.echo(f'{format_file_size(realBytesWritten)} bytes sent') click.echo(f'{format_file_size(realBytesWritten)}b {(time.time() - lastRealTimeMeasurement):.3f}s \t ', nl=False)
lastRealTimeMeasurement = time.time()
if bufLen < chunkSize: if bufLen < chunkSize:
break break
@ -144,7 +170,7 @@ def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamS
readyFile = InputFileBig( readyFile = InputFileBig(
fileId, currentFileTotalChunks, f'{int(time.time() * 1000)}.bin' fileId, currentFileTotalChunks, f'{int(time.time() * 1000)}.bin'
) )
client.send_file(dialog, readyFile, caption=textwrap.dedent(f''' await client.send_file(dialog, readyFile, caption=textwrap.dedent(f'''
{MESSAGE_BLOCK_START} {MESSAGE_BLOCK_START}
{MESSAGE_HEADER} {MESSAGE_HEADER}
#telecup_part #telecup_part_{nameHash} #telecup_part #telecup_part_{nameHash}
@ -167,14 +193,14 @@ def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamS
if currentFileChunkPos < currentFileTotalChunks: if currentFileChunkPos < currentFileTotalChunks:
buffer = bytes(bytearray([fillByte] * chunkSize)) buffer = bytes(bytearray([fillByte] * chunkSize))
for newChunkPos in range(currentFileChunkPos, currentFileTotalChunks): for newChunkPos in range(currentFileChunkPos, currentFileTotalChunks):
client(SaveBigFilePartRequest( await client(SaveBigFilePartRequest(
fileId, newChunkPos, currentFileTotalChunks, buffer fileId, newChunkPos, currentFileTotalChunks, buffer
)) ))
readyFile = InputFileBig( readyFile = InputFileBig(
fileId, currentFileTotalChunks, f'{int(time.time() * 1000)}.bin' fileId, currentFileTotalChunks, f'{int(time.time() * 1000)}.bin'
) )
client.send_file(dialog, readyFile, caption=textwrap.dedent(f''' await client.send_file(dialog, readyFile, caption=textwrap.dedent(f'''
{MESSAGE_BLOCK_START} {MESSAGE_BLOCK_START}
{MESSAGE_HEADER} {MESSAGE_HEADER}
#telecup_part #telecup_part_{nameHash} #telecup_part #telecup_part_{nameHash}
@ -187,7 +213,7 @@ def upload_file(client: TelegramClient, source: typing.BinaryIO, expectedStreamS
return UploadInfo(inputFiles, realBytesWritten) return UploadInfo(inputFiles, realBytesWritten)
def download_part(client: TelegramClient, dest: typing.BinaryIO, dInfo: DownloadInfo): async def download_part(client: TelegramClient, dest: typing.BinaryIO, dInfo: DownloadInfo):
dcId, inputFileLocation = get_input_location(dInfo.message) dcId, inputFileLocation = get_input_location(dInfo.message)
chunkSize = pow(2, 20) chunkSize = pow(2, 20)
realSize = int(dInfo.part_info['real_size']) realSize = int(dInfo.part_info['real_size'])
@ -199,7 +225,7 @@ def download_part(client: TelegramClient, dest: typing.BinaryIO, dInfo: Download
extraBytes = pow(2, 12) - limit % pow(2, 12) extraBytes = pow(2, 12) - limit % pow(2, 12)
limit += extraBytes limit += extraBytes
downloadResult: File = client(GetFileRequest( downloadResult: File = await client(GetFileRequest(
inputFileLocation, inputFileLocation,
offset, offset,
limit, limit,
@ -218,10 +244,10 @@ def download_part(client: TelegramClient, dest: typing.BinaryIO, dInfo: Download
click.echo(f"Part {dInfo.part_info['part']}: {format_file_size(totalBytesDownloaded)} ready", err=True) click.echo(f"Part {dInfo.part_info['part']}: {format_file_size(totalBytesDownloaded)} ready", err=True)
def download_file(client: TelegramClient, dest: typing.BinaryIO, fileInfo: dict, dialog: str): async def download_file(client: TelegramClient, dest: typing.BinaryIO, fileInfo: dict, dialog: str):
partMessagesSearchResults = client.iter_messages(dialog, search=f"#telecup_part_{fileInfo['name_hash']}") partMessagesSearchResults = client.iter_messages(dialog, search=f"#telecup_part_{fileInfo['name_hash']}")
partMessages: typing.List[typing.Optional[DownloadInfo]] = [None] * int(fileInfo['total_parts']) partMessages: typing.List[typing.Optional[DownloadInfo]] = [None] * int(fileInfo['total_parts'])
for msg in partMessagesSearchResults: async for msg in partMessagesSearchResults:
partInfo = parse_message(msg.message) partInfo = parse_message(msg.message)
if not partInfo or partInfo['name_hash'] != fileInfo['name_hash']: if not partInfo or partInfo['name_hash'] != fileInfo['name_hash']:
continue continue
@ -230,7 +256,7 @@ def download_file(client: TelegramClient, dest: typing.BinaryIO, fileInfo: dict,
raise click.ClickException('Missing some parts') raise click.ClickException('Missing some parts')
partMessages.sort(key=lambda dInfo: int(dInfo.part_info['part'])) partMessages.sort(key=lambda dInfo: int(dInfo.part_info['part']))
for part in partMessages: for part in partMessages:
download_part(client, dest, part) await download_part(client, dest, part)
click.echo(f"Part {part.part_info['part']} ready", err=True) click.echo(f"Part {part.part_info['part']} ready", err=True)
@ -240,7 +266,8 @@ def download_file(client: TelegramClient, dest: typing.BinaryIO, fileInfo: dict,
@click.option('--app-config', default='.telecup.cfg', show_default=True, help='App configuration file') @click.option('--app-config', default='.telecup.cfg', show_default=True, help='App configuration file')
@click.option('--dialog', default='me', help='Name of the conversation, where files will be stored [default: Saved Messages]') @click.option('--dialog', default='me', help='Name of the conversation, where files will be stored [default: Saved Messages]')
@pass_config @pass_config
def cli(config: Config, session_file, app_config, dialog): @coro
async def cli(config: Config, session_file, app_config, dialog):
config.session_file = session_file config.session_file = session_file
config.dialog = dialog config.dialog = dialog
@ -257,24 +284,26 @@ def cli(config: Config, session_file, app_config, dialog):
@cli.command() @cli.command()
@pass_config @pass_config
def login(config: Config): @coro
async def login(config: Config):
""" """
Log into your telegram account interactively and save the login information to session file Log into your telegram account interactively and save the login information to session file
""" """
config.client.start() await config.client.start()
click.echo(f'You are logged in as {config.client.get_me().first_name}') click.echo(f'You are logged in as {await config.client.get_me().first_name}')
click.echo('To switch user: either pass another session-file or remove existing') click.echo('To switch user: either pass another session-file or remove existing')
@cli.command(name='list') @cli.command(name='list')
@pass_config @pass_config
def list_files(config: Config): @coro
async def list_files(config: Config):
""" """
List all files, uploaded to your account by TeleCup List all files, uploaded to your account by TeleCup
""" """
check_logged_in(config) await check_logged_in(config)
fileMessages: typing.Iterator[Message] = config.client.iter_messages(config.dialog, search='#telecup_file') fileMessages = config.client.iter_messages(config.dialog, search='#telecup_file')
for msg in fileMessages: async for msg in fileMessages:
fileInfo = parse_message(msg.message) fileInfo = parse_message(msg.message)
if not fileInfo: if not fileInfo:
continue continue
@ -288,7 +317,8 @@ def list_files(config: Config):
@click.argument('filename') @click.argument('filename')
@click.argument('stream_size') @click.argument('stream_size')
@pass_config @pass_config
def upload(config: Config, part_size, filename, stream_size): @coro
async def upload(config: Config, part_size, filename, stream_size):
""" """
Upload a new file to your account Upload a new file to your account
""" """
@ -297,17 +327,15 @@ def upload(config: Config, part_size, filename, stream_size):
partSize = parse_file_size(part_size) partSize = parse_file_size(part_size)
if partSize > pow(2, 20) * 1536: if partSize > pow(2, 20) * 1536:
raise click.ClickException('Part size must be less than 1.5 Gib') raise click.ClickException('Part size must be less than 1.5 Gib')
check_logged_in(config) await check_logged_in(config)
nameHash = hashlib.sha256(filename.encode('utf-8')).hexdigest() nameHash = hashlib.sha256(filename.encode('utf-8')).hexdigest()
existingMessages = config.client.get_messages(config.dialog, search=f'#telecup_file_{nameHash}') existingMessages = await config.client.get_messages(config.dialog, search=f'#telecup_file_{nameHash}')
if len(existingMessages) > 0: if len(existingMessages) > 0:
raise click.ClickException(f'File with name {filename} already exists in dialog {config.dialog}') raise click.ClickException(f'File with name {filename} already exists in dialog {config.dialog}')
uploadInfo = upload_file(config.client, source, estimatedBytes, nameHash, config.dialog, partSize) uploadInfo = await upload_file(config.client, source, estimatedBytes, nameHash, config.dialog, partSize)
if uploadInfo.real_size == 0: await config.client.send_message(config.dialog, textwrap.dedent(f'''
raise click.ClickException('Input stream is empty')
config.client.send_message(config.dialog, textwrap.dedent(f'''
{MESSAGE_BLOCK_START} {MESSAGE_BLOCK_START}
{MESSAGE_HEADER} {MESSAGE_HEADER}
#telecup_file #telecup_file_{nameHash} #telecup_file #telecup_file_{nameHash}
@ -319,26 +347,27 @@ def upload(config: Config, part_size, filename, stream_size):
{MESSAGE_BLOCK_END} {MESSAGE_BLOCK_END}
''')) '''))
click.echo('OK') click.echo('OK')
config.client.disconnect() await config.client.disconnect()
@cli.command() @cli.command()
@click.argument('filename') @click.argument('filename')
@pass_config @pass_config
def download(config: Config, filename): @coro
async def download(config: Config, filename):
""" """
Download an existing file from your account Download an existing file from your account
""" """
destination = sys.stdout.buffer destination = sys.stdout.buffer
check_logged_in(config) await check_logged_in(config)
fileMessages: typing.Iterator[Message] = config.client.iter_messages(config.dialog, search='#telecup_file') fileMessages = config.client.iter_messages(config.dialog, search='#telecup_file')
for msg in fileMessages: async for msg in fileMessages:
fileInfo = parse_message(msg.message) fileInfo = parse_message(msg.message)
if fileInfo and ( if fileInfo and (
fileInfo['name'] == filename.strip() or fileInfo['name'] == filename.strip() or
fileInfo['name_hash'].startswith(filename.strip()) fileInfo['name_hash'].startswith(filename.strip())
): ):
download_file(config.client, destination, fileInfo, config.dialog) await download_file(config.client, destination, fileInfo, config.dialog)
click.echo('OK', err=True) click.echo('OK', err=True)
config.client.disconnect() await config.client.disconnect()