Skip to content

Cleanup parser & url classes #843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 1, 2021
Merged
3 changes: 1 addition & 2 deletions examples/https_connect_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from proxy import Proxy
from proxy.common.utils import build_http_response
from proxy.http import httpStatusCodes
from proxy.http.parser import httpParserStates
from proxy.core.base import BaseTcpTunnelHandler


Expand Down Expand Up @@ -58,7 +57,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:

# CONNECT requests are short and we need not worry about
# receiving partial request bodies here.
assert self.request.state == httpParserStates.COMPLETE
assert self.request.is_complete

# Establish connection with upstream
self.connect_upstream()
Expand Down
2 changes: 2 additions & 0 deletions proxy/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def _env_threadless_compliant() -> bool:
SLASH = b'/'
HTTP_1_0 = b'HTTP/1.0'
HTTP_1_1 = b'HTTP/1.1'
HTTP_URL_PREFIX = b'http://'
HTTPS_URL_PREFIX = b'https://'

PROXY_AGENT_HEADER_KEY = b'Proxy-agent'
PROXY_AGENT_HEADER_VALUE = b'proxy.py v' + \
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
# TODO(abhinavsingh): Remove .tobytes after parser is
# memoryview compliant
self.request.parse(data.tobytes())
if self.request.state == httpParserStates.COMPLETE:
if self.request.is_complete:
# Invoke plugin.on_request_complete
for plugin in self.plugins.values():
upgraded_sock = plugin.on_request_complete()
Expand Down
136 changes: 73 additions & 63 deletions proxy/http/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...common.constants import DEFAULT_DISABLE_HEADERS, COLON, DEFAULT_ENABLE_PROXY_PROTOCOL
from ...common.constants import HTTP_1_1, SLASH, CRLF
from ...common.constants import WHITESPACE, DEFAULT_HTTP_PORT
from ...common.utils import build_http_request, build_http_response, find_http_line, text_
from ...common.utils import build_http_request, build_http_response, text_
from ...common.flag import flags

from ..url import Url
Expand Down Expand Up @@ -63,10 +63,12 @@ def __init__(
if enable_proxy_protocol:
assert self.type == httpParserTypes.REQUEST_PARSER
self.protocol = ProxyProtocol()
# Request attributes
self.host: Optional[bytes] = None
self.port: Optional[int] = None
self.path: Optional[bytes] = None
self.method: Optional[bytes] = None
# Response attributes
self.code: Optional[bytes] = None
self.reason: Optional[bytes] = None
self.version: Optional[bytes] = None
Expand All @@ -78,7 +80,7 @@ def __init__(
# - Keys are lower case header names.
# - Values are 2-tuple containing original
# header and it's value as received.
self.headers: Dict[bytes, Tuple[bytes, bytes]] = {}
self.headers: Optional[Dict[bytes, Tuple[bytes, bytes]]] = None
self.body: Optional[bytes] = None
self.chunk: Optional[ChunkParser] = None
# Internal request line as a url structure
Expand Down Expand Up @@ -109,19 +111,24 @@ def response(cls: Type[T], raw: bytes) -> T:

def header(self, key: bytes) -> bytes:
    """Return the original (as-received) value for header ``key``.

    Lookup is case-insensitive.

    Raises:
        KeyError: if no headers have been parsed yet or the key is absent.
    """
    if self.headers is None or key.lower() not in self.headers:
        # Use % interpolation: passing two arguments to KeyError would
        # make the exception message an unformatted tuple.
        raise KeyError('%s not found in headers' % text_(key))
    return self.headers[key.lower()][1]

def has_header(self, key: bytes) -> bool:
    """Tell whether a header named ``key`` (case-insensitive) was received."""
    headers = self.headers
    return False if headers is None else key.lower() in headers

def add_header(self, key: bytes, value: bytes) -> bytes:
    """Insert or overwrite a header in the internal store.

    Returns the lower-cased key under which the original
    ``(key, value)`` tuple is now reachable."""
    if self.headers is None:
        # Lazily created: stays None until the first header arrives.
        self.headers = {}
    lowered = key.lower()
    self.headers[lowered] = (key, value)
    return lowered

Expand All @@ -132,7 +139,7 @@ def add_headers(self, headers: List[Tuple[bytes, bytes]]) -> None:

def del_header(self, header: bytes) -> None:
    """Remove a header (case-insensitive) if present; no-op otherwise.

    Guards on ``self.headers`` being truthy because it stays ``None``
    until the first header is parsed or added.
    """
    if self.headers and header.lower() in self.headers:
        del self.headers[header.lower()]

def del_headers(self, headers: List[bytes]) -> None:
Expand All @@ -151,6 +158,10 @@ def has_host(self) -> bool:
NOTE: Host field WILL be None for incoming local WebServer requests."""
return self.host is not None

@property
def is_complete(self) -> bool:
    """True once the parser has consumed an entire request/response payload."""
    return self.state == httpParserStates.COMPLETE

@property
def is_http_1_1_keep_alive(self) -> bool:
"""Returns true for HTTP/1.1 keep-alive connections."""
Expand Down Expand Up @@ -185,30 +196,34 @@ def content_expected(self) -> bool:
@property
def body_expected(self) -> bool:
    """Returns true if content or chunked response is expected."""
    # NOTE(review): uses the private pre-computed flags (set during header
    # parsing) rather than re-deriving from headers on every access.
    return self._content_expected or self._is_chunked_encoded

def parse(self, raw: bytes) -> None:
    """Parses HTTP request out of raw bytes.

    Check for `HttpParser.state` after `parse` has successfully returned.

    Incoming bytes are stitched onto any leftover ``self.buffer`` from a
    previous call; whatever remains unconsumed is buffered again at the end.
    """
    size = len(raw)
    self.total_size += size
    raw = self.buffer + raw
    # ``more`` keys off the fresh chunk size, not the stitched buffer,
    # so an empty read cannot spin the loop.
    self.buffer, more = b'', size > 0
    while more and self.state != httpParserStates.COMPLETE:
        # gte with HEADERS_COMPLETE also encapsulates RCVING_BODY state
        if self.state >= httpParserStates.HEADERS_COMPLETE:
            more, raw = self._process_body(raw)
        elif self.state == httpParserStates.INITIALIZED:
            more, raw = self._process_line(raw)
        else:
            more, raw = self._process_headers(raw)
        # When server sends a response line without any header or body e.g.
        # HTTP/1.1 200 Connection established\r\n\r\n
        if self.type == httpParserTypes.RESPONSE_PARSER and \
                self.state == httpParserStates.LINE_RCVD and \
                raw == CRLF:
            self.state = httpParserStates.COMPLETE
        # Mark request as complete if headers received and no incoming
        # body indication received.
        elif self.state == httpParserStates.HEADERS_COMPLETE and \
                not (self._content_expected or self._is_chunked_encoded) and \
                raw == b'':
            self.state = httpParserStates.COMPLETE
    self.buffer = raw
Expand All @@ -229,7 +244,7 @@ def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool =
COLON +
str(self.port).encode() +
path
) if not self.is_https_tunnel else (self.host + COLON + str(self.port).encode())
) if not self._is_https_tunnel else (self.host + COLON + str(self.port).encode())
return build_http_request(
self.method, path, self.version,
headers={} if not self.headers else {
Expand Down Expand Up @@ -263,15 +278,15 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]:
# the latter MUST be ignored.
#
# TL;DR -- Give transfer-encoding header preference over content-length.
if self.is_chunked_encoded:
if self._is_chunked_encoded:
if not self.chunk:
self.chunk = ChunkParser()
raw = self.chunk.parse(raw)
if self.chunk.state == chunkParserStates.COMPLETE:
self.body = self.chunk.body
self.state = httpParserStates.COMPLETE
more = False
elif self.content_expected:
elif self._content_expected:
self.state = httpParserStates.RCVING_BODY
if self.body is None:
self.body = b''
Expand All @@ -297,7 +312,7 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]:
more, raw = False, b''
return more, raw

def _process_headers(self, raw: bytes) -> Tuple[bool, bytes]:
    """Consume header lines out of ``raw``.

    Returns ``(more, unconsumed)``; ``more`` is False when no CRLF could
    be found in received bytes.

    TODO: We should not return until parser reaches headers complete
    state or there is no more data left to parse.  This will also help
    make the parser even more stateless.
    """
    while True:
        parts = raw.split(CRLF, 1)
        if len(parts) == 1:
            # Partial header line received; wait for more bytes.
            return False, raw
        line, raw = parts[0], parts[1]
        if self.state in (httpParserStates.LINE_RCVD, httpParserStates.RCVING_HEADERS):
            if line == b'' or line.strip() == b'':  # Blank line received.
                self.state = httpParserStates.HEADERS_COMPLETE
            else:
                self.state = httpParserStates.RCVING_HEADERS
                self._process_header(line)
        # If we have received all headers, bail out
        if raw == b'' or self.state == httpParserStates.HEADERS_COMPLETE:
            break
    return len(raw) > 0, raw

def _process_line(self, raw: bytes) -> Tuple[bool, bytes]:
    """Parse the request/response start line (and any proxy protocol v1 line).

    Returns ``(more, unconsumed)``; ``more`` is False when no CRLF could
    be found in received bytes.
    """
    while True:
        parts = raw.split(CRLF, 1)
        if len(parts) == 1:
            # Partial start line; wait for more bytes.
            return False, raw
        line, raw = parts[0], parts[1]
        if self.type == httpParserTypes.REQUEST_PARSER:
            if self.protocol is not None and self.protocol.version is None:
                # We expect to receive entire proxy protocol v1 line
                # in one network read and don't expect partial packets
                self.protocol.parse(line)
                continue
            # Ref: https://datatracker.ietf.org/doc/html/rfc2616#section-5.1
            parts = line.split(WHITESPACE, 2)
            if len(parts) == 3:
                self.method = parts[0]
                if self.method == httpMethods.CONNECT:
                    self._is_https_tunnel = True
                self.set_url(parts[1])
                self.version = parts[2]
                self.state = httpParserStates.LINE_RCVD
                break
            # To avoid a possible attack vector, we raise exception
            # if parser receives an invalid request line.
            #
            # TODO: Better to use raise HttpProtocolException,
            # but we should solve circular import problem first.
            raise ValueError('Invalid request line')
        parts = line.split(WHITESPACE, 2)
        self.version = parts[0]
        self.code = parts[1]
        # Our own WebServerPlugin example currently doesn't send any reason
        if len(parts) == 3:
            self.reason = parts[2]
        self.state = httpParserStates.LINE_RCVD
        break
    return len(raw) > 0, raw

def _process_header(self, raw: bytes) -> None:
parts = raw.split(COLON, 1)
Expand All @@ -380,20 +394,16 @@ def _process_header(self, raw: bytes) -> None:

def _get_body_or_chunks(self) -> Optional[bytes]:
return ChunkParser.to_chunks(self.body) \
if self.body and self.is_chunked_encoded else \
if self.body and self._is_chunked_encoded else \
self.body

def _set_line_attributes(self) -> None:
    """Derive ``host``, ``port`` and ``path`` from the parsed url.

    Only meaningful for request parsers; response start lines carry no url,
    hence the early type check.
    """
    if self.type == httpParserTypes.REQUEST_PARSER:
        # By the time the request line has been parsed, a url must exist.
        assert self._url
        if self._is_https_tunnel:
            # CONNECT requests default to port 443 when none was given.
            self.host = self._url.hostname
            self.port = 443 if self._url.port is None else self._url.port
        else:
            self.host, self.port = self._url.hostname, self._url.port \
                if self._url.port else DEFAULT_HTTP_PORT
        self.path = self._url.remainder
2 changes: 1 addition & 1 deletion proxy/http/proxy/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class AuthPlugin(HttpProxyBasePlugin):
def before_upstream_connection(
self, request: HttpParser,
) -> Optional[HttpParser]:
if self.flags.auth_code:
if self.flags.auth_code and request.headers:
if b'proxy-authorization' not in request.headers:
raise ProxyAuthenticationFailed()
parts = request.headers[b'proxy-authorization'][1].split()
Expand Down
Loading