# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Internal network layer helper methods."""
from __future__ import annotations

import asyncio
import collections
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
from typing import (
    TYPE_CHECKING,
    Any,
    Optional,
    Union,
)

from pymongo import _csot, ssl_support
from pymongo._asyncio_task import create_task
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress
from pymongo.errors import ProtocolError, _OperationCancelled
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.socket_checker import _errno_from_exception

try:
    from ssl import SSLError, SSLSocket

    _HAVE_SSL = True
except ImportError:
    _HAVE_SSL = False

try:
    from pymongo.pyopenssl_context import _sslConn

    _HAVE_PYOPENSSL = True
except ImportError:
    _HAVE_PYOPENSSL = False
    _sslConn = SSLSocket  # type: ignore[assignment, misc]

from pymongo.ssl_support import (
    BLOCKING_IO_LOOKUP_ERROR,
    BLOCKING_IO_READ_ERROR,
    BLOCKING_IO_WRITE_ERROR,
)

if TYPE_CHECKING:
    from pymongo.asynchronous.pool import AsyncConnection
    from pymongo.synchronous.pool import Connection

# Wire protocol standard header: messageLength, requestID, responseTo, opCode
# (16 bytes, four little-endian int32s).
_UNPACK_HEADER = struct.Struct("<iiii").unpack
# OP_COMPRESSED header that follows the standard header: originalOpcode,
# uncompressedSize (int32s) and compressorId (uint8) -- 9 bytes in total.
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
# How long (in seconds) a blocking read waits before re-checking for cancellation.
_POLL_TIMEOUT = 0.5
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)


# These socket-based I/O methods are for KMS requests and any other network operations that do not use
# the MongoDB wire protocol
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
    """Send all of ``buf`` over ``sock`` from within a running event loop.

    The socket is temporarily switched to non-blocking mode; its original
    timeout is used as the overall asyncio deadline and restored afterwards.

    :raises socket.timeout: if the send does not finish within the socket's timeout.
    """
    timeout = sock.gettimeout()
    sock.settimeout(0.0)
    loop = asyncio.get_running_loop()
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
        else:
            await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout)  # type: ignore[arg-type]
    except asyncio.TimeoutError as exc:
        # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
        raise socket.timeout("timed out") from exc
    finally:
        sock.settimeout(timeout)


if sys.platform != "win32":

    def _is_ready(fut: Future) -> None:
        """Reader/writer callback: resolve ``fut`` unless it is already done."""
        if fut.done():
            return
        fut.set_result(None)

    async def _async_wait_for_fd(loop: AbstractEventLoop, fd: int, exc: BaseException) -> None:
        """Wait until ``fd`` is ready for the I/O that the blocking error ``exc`` requested.

        A BLOCKING_IO_READ_ERROR waits for readability, a BLOCKING_IO_WRITE_ERROR
        waits for writability and, with PyOpenSSL, a BLOCKING_IO_LOOKUP_ERROR waits
        for either. Any other error returns immediately so the caller simply retries.
        The registered reader/writer is always removed again, even on cancellation.
        """
        if isinstance(exc, BLOCKING_IO_READ_ERROR):
            fut = loop.create_future()
            loop.add_reader(fd, _is_ready, fut)
            try:
                await fut
            finally:
                loop.remove_reader(fd)
        if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
            fut = loop.create_future()
            loop.add_writer(fd, _is_ready, fut)
            try:
                await fut
            finally:
                loop.remove_writer(fd)
        if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
            fut = loop.create_future()
            loop.add_reader(fd, _is_ready, fut)
            try:
                loop.add_writer(fd, _is_ready, fut)
                await fut
            finally:
                loop.remove_reader(fd)
                loop.remove_writer(fd)

    async def _async_socket_sendall_ssl(
        sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
    ) -> None:
        """Send all of ``buf`` over a non-blocking TLS socket using loop readiness callbacks.

        :raises SSLError: if the underlying socket has been closed.
        """
        view = memoryview(buf)
        sent = 0
        while sent < len(buf):
            try:
                sent += sock.send(view[sent:])
            except BLOCKING_IO_ERRORS as exc:
                fd = sock.fileno()
                # Check for closed socket.
                if fd == -1:
                    raise SSLError("Underlying socket has been closed") from None
                await _async_wait_for_fd(loop, fd, exc)

    async def _async_socket_receive_ssl(
        conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
    ) -> memoryview:
        """Read ``length`` bytes from a non-blocking TLS socket.

        When ``once`` is true, return after the first successful read, which may
        hold fewer than ``length`` bytes.

        :raises OSError: if the peer closed the connection.
        :raises SSLError: if the underlying socket has been closed.
        """
        mv = memoryview(bytearray(length))
        total_read = 0
        while total_read < length:
            try:
                read = conn.recv_into(mv[total_read:])
                if read == 0:
                    raise OSError("connection closed")
                # KMS responses update their expected size after the first batch, stop reading after one loop
                if once:
                    return mv[:read]
                total_read += read
            except BLOCKING_IO_ERRORS as exc:
                fd = conn.fileno()
                # Check for closed socket.
                if fd == -1:
                    raise SSLError("Underlying socket has been closed") from None
                await _async_wait_for_fd(loop, fd, exc)
        return mv
else:
    # The default Windows asyncio event loop does not support loop.add_reader/add_writer:
    # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
    # Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
    async def _async_socket_sendall_ssl(
        sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
    ) -> None:
        """Send all of ``buf`` over a non-blocking TLS socket by polling with backoff."""
        view = memoryview(buf)
        total_length = len(buf)
        total_sent = 0
        # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
        # down to 1ms.
        backoff = 0.001
        while total_sent < total_length:
            try:
                sent = sock.send(view[total_sent:])
            except BLOCKING_IO_ERRORS:
                await asyncio.sleep(backoff)
                sent = 0
            if sent > 0:
                backoff = max(backoff / 2, 0.001)
            else:
                backoff = min(backoff * 2, 0.512)
            total_sent += sent

    async def _async_socket_receive_ssl(
        conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
    ) -> memoryview:
        """Read ``length`` bytes from a non-blocking TLS socket by polling with backoff.

        When ``once`` is true, return after the first successful read, which may
        hold fewer than ``length`` bytes.

        :raises OSError: if the peer closed the connection.
        """
        mv = memoryview(bytearray(length))
        total_read = 0
        # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
        # down to 1ms.
        backoff = 0.001
        while total_read < length:
            try:
                read = conn.recv_into(mv[total_read:])
                if read == 0:
                    raise OSError("connection closed")
                # KMS responses update their expected size after the first batch, stop reading after one loop
                if once:
                    return mv[:read]
            except BLOCKING_IO_ERRORS:
                await asyncio.sleep(backoff)
                read = 0
            if read > 0:
                backoff = max(backoff / 2, 0.001)
            else:
                backoff = min(backoff * 2, 0.512)
            total_read += read
        return mv


def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
    """Send all of ``buf`` over ``sock`` using the socket's blocking ``sendall``."""
    sock.sendall(buf)


async def _poll_cancellation(conn: AsyncConnection) -> None:
    """Return once ``conn.cancel_context`` reports cancellation, checking every _POLL_TIMEOUT seconds."""
    while True:
        if conn.cancel_context.cancelled:
            return
        await asyncio.sleep(_POLL_TIMEOUT)


async def async_receive_data_socket(
    sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
    """Receive data from ``sock`` from within a running event loop.

    For TLS sockets this returns after the first successful read (which may be
    shorter than ``length``); for plain sockets it reads exactly ``length`` bytes.
    The socket's original timeout is used as the deadline and restored afterwards.

    :raises socket.timeout: if the read does not finish within the socket's timeout.
    """
    sock_timeout = sock.gettimeout()
    sock.settimeout(0.0)
    loop = asyncio.get_running_loop()
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            return await asyncio.wait_for(
                _async_socket_receive_ssl(sock, length, loop, once=True),  # type: ignore[arg-type]
                timeout=sock_timeout,
            )
        else:
            return await asyncio.wait_for(
                _async_socket_receive(sock, length, loop),  # type: ignore[arg-type]
                timeout=sock_timeout,
            )
    except asyncio.TimeoutError as err:
        raise socket.timeout("timed out") from err
    finally:
        sock.settimeout(sock_timeout)


async def _async_socket_receive(
    conn: socket.socket, length: int, loop: AbstractEventLoop
) -> memoryview:
    """Read exactly ``length`` bytes from a plain socket via ``loop.sock_recv_into``.

    :raises OSError: if the peer closed the connection before ``length`` bytes arrived.
    """
    mv = memoryview(bytearray(length))
    bytes_read = 0
    while bytes_read < length:
        chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
        if chunk_length == 0:
            raise OSError("connection closed")
        bytes_read += chunk_length
    return mv


_PYPY = "PyPy" in sys.version
_WINDOWS = sys.platform == "win32"


def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Block until the socket is readable, the deadline passes, or the operation is cancelled.

    Returns immediately if the socket has already been closed (``fileno() == -1``).

    :raises _OperationCancelled: if ``conn.cancel_context`` is cancelled.
    :raises socket.timeout: if ``deadline`` passes without the socket becoming readable.
    """
    sock = conn.conn.sock
    timed_out = False
    # Check if the connection's socket has been manually closed
    if sock.fileno() == -1:
        return
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            readable = True
        else:
            # Wait up to 500ms for the socket to become readable and then
            # check for cancellation.
            if deadline:
                remaining = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                if remaining <= 0:
                    timed_out = True
                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
            else:
                timeout = _POLL_TIMEOUT
            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if readable:
            return
        if timed_out:
            raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    """Read exactly ``length`` bytes from ``conn``, honoring ``deadline`` and cancellation.

    The connection's original socket timeout is restored before returning.

    :raises _OperationCancelled: if ``conn.cancel_context`` is cancelled.
    :raises socket.timeout: if the deadline is reached.
    :raises OSError: if the peer closes the connection or another socket error occurs.
    """
    buf = bytearray(length)
    mv = memoryview(buf)
    bytes_read = 0
    # To support cancelling a network read, we shorten the socket timeout and
    # check for the cancellation signal after each timeout. Alternatively we
    # could close the socket but that does not reliably cancel recv() calls
    # on all OSes.
    # When the timeout has expired we perform one final non-blocking recv.
    # This helps avoid spurious timeouts when the response is actually already
    # buffered on the client.
    orig_timeout = conn.conn.gettimeout()
    try:
        while bytes_read < length:
            try:
                # Use the legacy wait_for_read cancellation approach on PyPy due to PYTHON-5011.
                # also use it on Windows due to PYTHON-5405
                if _PYPY or _WINDOWS:
                    wait_for_read(conn, deadline)
                    if _csot.get_timeout() and deadline is not None:
                        conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
                else:
                    if deadline is not None:
                        short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
                    else:
                        short_timeout = _POLL_TIMEOUT
                    conn.set_conn_timeout(short_timeout)

                chunk_length = conn.conn.recv_into(mv[bytes_read:])
            except BLOCKING_IO_ERRORS:
                if conn.cancel_context.cancelled:
                    raise _OperationCancelled("operation cancelled") from None
                # We reached the true deadline.
                raise socket.timeout("timed out") from None
            except socket.timeout:
                if conn.cancel_context.cancelled:
                    raise _OperationCancelled("operation cancelled") from None
                if (
                    _PYPY
                    or _WINDOWS
                    or (
                        not conn.is_sdam
                        and deadline is not None
                        and deadline - time.monotonic() < 0
                    )
                ):
                    # We reached the true deadline.
                    raise
                # Otherwise this was only a short poll timeout: check again.
                continue
            except OSError as exc:
                if conn.cancel_context.cancelled:
                    raise _OperationCancelled("operation cancelled") from None
                if _errno_from_exception(exc) == errno.EINTR:
                    continue
                raise
            if chunk_length == 0:
                raise OSError("connection closed")

            bytes_read += chunk_length
    finally:
        conn.set_conn_timeout(orig_timeout)

    return mv


class NetworkingInterfaceBase:
    """Common interface over a synchronous socket or an asyncio transport/protocol pair."""

    def __init__(self, conn: Any):
        self.conn = conn

    @property
    def gettimeout(self) -> Any:
        raise NotImplementedError

    def settimeout(self, timeout: float | None) -> None:
        raise NotImplementedError

    def close(self) -> Any:
        raise NotImplementedError

    def is_closing(self) -> bool:
        raise NotImplementedError

    @property
    def get_conn(self) -> Any:
        raise NotImplementedError

    @property
    def sock(self) -> Any:
        raise NotImplementedError


class AsyncNetworkingInterface(NetworkingInterfaceBase):
    """Networking interface wrapping an asyncio ``(Transport, PyMongoProtocol)`` pair."""

    def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
        super().__init__(conn)

    @property
    def gettimeout(self) -> float | None:
        return self.conn[1].gettimeout

    def settimeout(self, timeout: float | None) -> None:
        self.conn[1].settimeout(timeout)

    async def close(self) -> None:
        """Close the protocol and wait until the connection is reported lost."""
        self.conn[1].close()
        await self.conn[1].wait_closed()

    def is_closing(self) -> bool:
        return self.conn[0].is_closing()

    @property
    def get_conn(self) -> PyMongoProtocol:
        return self.conn[1]

    @property
    def sock(self) -> socket.socket:
        return self.conn[0].get_extra_info("socket")


class NetworkingInterface(NetworkingInterfaceBase):
    """Networking interface wrapping a blocking socket (plain or TLS)."""

    def __init__(self, conn: Union[socket.socket, _sslConn]):
        super().__init__(conn)

    def gettimeout(self) -> float | None:
        return self.conn.gettimeout()

    def settimeout(self, timeout: float | None) -> None:
        self.conn.settimeout(timeout)

    def close(self) -> None:
        self.conn.close()

    def is_closing(self) -> bool:
        """Return True once the wrapped socket has been closed.

        Sockets have no ``is_closing()`` method; a closed socket reports a
        ``fileno()`` of -1, the same check ``wait_for_read`` relies on.
        """
        return self.conn.fileno() == -1

    @property
    def get_conn(self) -> Union[socket.socket, _sslConn]:
        return self.conn

    @property
    def sock(self) -> Union[socket.socket, _sslConn]:
        return self.conn

    def fileno(self) -> int:
        return self.conn.fileno()

    def recv_into(self, buffer: Union[bytearray, memoryview]) -> int:
        return self.conn.recv_into(buffer)


class PyMongoProtocol(BufferedProtocol):
    """asyncio buffered protocol that frames incoming MongoDB Wire Protocol messages.

    Completed messages are delivered through futures: ``read()`` registers a
    waiter in ``_pending_messages`` and ``buffer_updated()`` resolves it once a
    full message body has arrived.
    """

    def __init__(self, timeout: Optional[float] = None):
        self.transport: Transport = None  # type: ignore[assignment]
        # Each message is read in 2-3 parts: header, compression header, and message body
        # The message buffer is allocated after the header is read.
        self._header = memoryview(bytearray(16))
        self._header_index = 0
        self._compression_header = memoryview(bytearray(9))
        self._compression_index = 0
        self._message: Optional[memoryview] = None
        self._message_index = 0
        # State. TODO: replace booleans with an enum?
        self._expecting_header = True
        self._expecting_compression = False
        self._message_size = 0
        self._op_code = 0
        self._connection_lost = False
        self._read_waiter: Optional[Future] = None
        self._timeout = timeout
        self._is_compressed = False
        self._compressor_id: Optional[int] = None
        self._max_message_size = MAX_MESSAGE_SIZE
        self._response_to: Optional[int] = None
        self._closed = asyncio.get_running_loop().create_future()
        self._pending_messages: collections.deque[Future] = collections.deque()
        self._done_messages: collections.deque[Future] = collections.deque()

    def settimeout(self, timeout: float | None) -> None:
        self._timeout = timeout

    @property
    def gettimeout(self) -> float | None:
        """The configured timeout for the socket that underlies our protocol pair."""
        return self._timeout

    def connection_made(self, transport: BaseTransport) -> None:
        """Called exactly once when a connection is made.
        The transport argument is the transport representing the write side of the connection.
        """
        self.transport = transport  # type: ignore[assignment]
        self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)

    async def write(self, message: bytes) -> None:
        """Write a message to this connection's transport."""
        if self.transport.is_closing():
            raise OSError("Connection is closed")
        self.transport.write(message)
        self.transport.resume_reading()

    async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
        """Read a single MongoDB Wire Protocol message from this connection.

        Returns ``(data, op_code)`` with ``data`` already decompressed.

        :raises ProtocolError: if ``request_id`` is given and does not match the reply.
        :raises OSError: if the connection is (or becomes) closed.
        """
        if self.transport:
            try:
                self.transport.resume_reading()
            # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
            except AttributeError:
                raise OSError("connection is already closed") from None
        self._max_message_size = max_message_size
        if self._done_messages:
            message = await self._done_messages.popleft()
        else:
            if self.transport and self.transport.is_closing():
                raise OSError("connection is already closed")
            read_waiter = asyncio.get_running_loop().create_future()
            self._pending_messages.append(read_waiter)
            try:
                message = await read_waiter
            finally:
                if read_waiter in self._done_messages:
                    self._done_messages.remove(read_waiter)
        if message:
            op_code, compressor_id, response_to, data = message
            # No request_id for exhaust cursor "getMore".
            if request_id is not None:
                if request_id != response_to:
                    raise ProtocolError(
                        f"Got response id {response_to!r} but expected {request_id!r}"
                    )
            if compressor_id is not None:
                data = decompress(data, compressor_id)
            return data, op_code
        # A waiter resolved with None means the connection was closed without error.
        raise OSError("connection closed")

    def get_buffer(self, sizehint: int) -> memoryview:
        """Called to allocate a new receive buffer.
        The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
        If any data does not fit into the returned buffer, this method will be called again until
        either no data remains or an empty buffer is returned.
        """
        # Due to a bug, Python <=3.11 will call get_buffer() even after we raise
        # ProtocolError in buffer_updated() and call connection_lost(). We allocate
        # a temp buffer to drain the waiting data.
        if self._connection_lost:
            if not self._message:
                self._message = memoryview(bytearray(2**14))
            return self._message
        # TODO: optimize this by caching pointers to the buffers.
        # return self._buffer[self._index:]
        if self._expecting_header:
            return self._header[self._header_index :]
        if self._expecting_compression:
            return self._compression_header[self._compression_index :]
        return self._message[self._message_index :]  # type: ignore[index]

    def buffer_updated(self, nbytes: int) -> None:
        """Called when the buffer was updated with the received data"""
        # Wrote 0 bytes into a non-empty buffer, signal connection closed
        if nbytes == 0:
            self.close(OSError("connection closed"))
            return
        if self._connection_lost:
            return
        if self._expecting_header:
            self._header_index += nbytes
            if self._header_index >= 16:
                self._expecting_header = False
                try:
                    (
                        self._message_size,
                        self._op_code,
                        self._response_to,
                        self._expecting_compression,
                    ) = self.process_header()
                except ProtocolError as exc:
                    self.close(exc)
                    return
                self._message = memoryview(bytearray(self._message_size))
            # Either the header is still incomplete or the body buffer was just allocated.
            return
        if self._expecting_compression:
            self._compression_index += nbytes
            if self._compression_index >= 9:
                self._expecting_compression = False
                self._op_code, self._compressor_id = self.process_compression_header()
            return

        self._message_index += nbytes
        if self._message_index >= self._message_size:
            self._expecting_header = True
            # Pause reading to avoid storing an arbitrary number of messages in memory.
            self.transport.pause_reading()
            if self._pending_messages:
                result = self._pending_messages.popleft()
            else:
                result = asyncio.get_running_loop().create_future()
            # Future has been cancelled, close this connection
            if result.done():
                self.close(None)
                return
            # Necessary values to reconstruct and verify message
            result.set_result(
                (self._op_code, self._compressor_id, self._response_to, self._message)
            )
            self._done_messages.append(result)
            # Reset internal state to expect a new message
            self._header_index = 0
            self._compression_index = 0
            self._message_index = 0
            self._message_size = 0
            self._message = None
            self._op_code = 0
            self._compressor_id = None
            self._response_to = None

    def process_header(self) -> tuple[int, int, int, bool]:
        """Unpack a MongoDB Wire Protocol header.

        Returns ``(body_length, op_code, response_to, expecting_compression)``
        where ``body_length`` excludes the 16-byte header and, for OP_COMPRESSED,
        the 9-byte compression header.

        :raises ProtocolError: if the length is too small or exceeds the max message size.
        """
        length, _, response_to, op_code = _UNPACK_HEADER(self._header)
        expecting_compression = False
        if op_code == 2012:  # OP_COMPRESSED
            if length <= 25:
                raise ProtocolError(
                    f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
                )
            expecting_compression = True
            length -= 9
        if length <= 16:
            raise ProtocolError(
                f"Message length ({length!r}) not longer than standard message header size (16)"
            )
        if length > self._max_message_size:
            raise ProtocolError(
                f"Message length ({length!r}) is larger than server max "
                f"message size ({self._max_message_size!r})"
            )

        return length - 16, op_code, response_to, expecting_compression

    def process_compression_header(self) -> tuple[int, int]:
        """Unpack a MongoDB Wire Protocol compression header.

        Returns ``(original_op_code, compressor_id)``.
        """
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
        return op_code, compressor_id

    def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
        """Resolve every unfinished pending read with ``exc`` (or None) and mark it done."""
        pending = list(self._pending_messages)
        for msg in pending:
            if not msg.done():
                if exc is None:
                    msg.set_result(None)
                else:
                    msg.set_exception(exc)
                self._done_messages.append(msg)

    def close(self, exc: Optional[Exception] = None) -> None:
        """Abort the transport and fail pending reads with ``exc`` (None means a clean close)."""
        self.transport.abort()
        self._resolve_pending_messages(exc)
        self._connection_lost = True

    def connection_lost(self, exc: Optional[Exception] = None) -> None:
        """asyncio callback: resolve pending reads and signal ``wait_closed()``."""
        self._resolve_pending_messages(exc)
        if not self._closed.done():
            self._closed.set_result(None)

    async def wait_closed(self) -> None:
        """Wait until ``connection_lost()`` has been called."""
        await self._closed


async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
    """Write ``buf`` through the protocol, bounded by the protocol's configured timeout.

    :raises socket.timeout: if the write does not complete in time.
    """
    try:
        await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
    except asyncio.TimeoutError as exc:
        # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
        raise socket.timeout("timed out") from exc


async def async_receive_message(
    conn: AsyncConnection,
    request_id: Optional[int],
    max_message_size: int = MAX_MESSAGE_SIZE,
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error.

    The read races against a cancellation poller; whichever finishes first wins.

    :raises socket.timeout: if neither task completes before the deadline.
    :raises _OperationCancelled: if the connection's cancel context fires first.
    :raises ProtocolError: if the reply has an unknown opcode or mismatched response id.
    """
    timeout: Optional[Union[float, int]]
    timeout = conn.conn.gettimeout
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        if timeout:
            deadline = time.monotonic() + timeout
        else:
            deadline = None
    if deadline:
        # When the timeout has expired perform one final check to
        # see if the socket is readable. This helps avoid spurious
        # timeouts on AWS Lambda and other FaaS environments.
        timeout = max(deadline - time.monotonic(), 0)

    cancellation_task = create_task(_poll_cancellation(conn))
    read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
    tasks = [read_task, cancellation_task]
    try:
        done, pending = await asyncio.wait(
            tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.wait(pending)
        if len(done) == 0:
            raise socket.timeout("timed out")
        if read_task in done:
            data, op_code = read_task.result()
            try:
                unpack_reply = _UNPACK_REPLY[op_code]
            except KeyError:
                raise ProtocolError(
                    f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
                ) from None
            return unpack_reply(data)
        raise _OperationCancelled("operation cancelled")
    except asyncio.CancelledError:
        # Don't leak the child tasks when the caller itself is cancelled.
        for task in tasks:
            task.cancel()
        await asyncio.wait(tasks)
        raise


def receive_message(
    conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error.

    :raises ProtocolError: on a mismatched response id, an invalid message
        length, or an unknown opcode.
    """
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        timeout = conn.conn.gettimeout()
        if timeout:
            deadline = time.monotonic() + timeout
        else:
            deadline = None
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
    # No request_id for exhaust cursor "getMore".
    if request_id is not None:
        if request_id != response_to:
            raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
    if length <= 16:
        raise ProtocolError(
            f"Message length ({length!r}) not longer than standard message header size (16)"
        )
    if length > max_message_size:
        raise ProtocolError(
            f"Message length ({length!r}) is larger than server max "
            f"message size ({max_message_size!r})"
        )
    if op_code == 2012:  # OP_COMPRESSED
        # Same guard as PyMongoProtocol.process_header: without it, length - 25
        # below would be zero or negative and bytearray() would reject it.
        if length <= 25:
            raise ProtocolError(
                f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
            )
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
        data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
    else:
        data = receive_data(conn, length - 16, deadline)

    try:
        unpack_reply = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError(
            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
        ) from None
    return unpack_reply(data)