]> git.the-white-hart.net Git - gemini/cbs-server.git/commitdiff
Refactor main
authorrs <>
Mon, 22 Dec 2025 19:29:25 +0000 (13:29 -0600)
committerrs <>
Mon, 22 Dec 2025 19:29:25 +0000 (13:29 -0600)
cbs-srv.py

index 1940529125a0dfb5af63010e115f8d576a4c4d86..5aae7764903b7c70fa1cca4c8c9253d0fcf4b82e 100755 (executable)
@@ -4,13 +4,14 @@ import select
 import socket
 from OpenSSL import SSL
 from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
 import socket
 from OpenSSL import SSL
 from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
-from urllib.parse import urlparse, unquote
+from urllib.parse import urlsplit, urlunsplit, unquote, ParseResult
 
 import re
 from os import path, environ
 import subprocess
 import mimetypes
 
 
 import re
 from os import path, environ
 import subprocess
 import mimetypes
 
+import traceback
 import logging
 import yaml
 import time
 import logging
 import yaml
 import time
@@ -24,6 +25,7 @@ mimetypes.add_type('text/gemini', '.gemini')
 
 
 def accept_client_cert(conn, cert, err_num, err_depth, ret_code):
 
 
 def accept_client_cert(conn, cert, err_num, err_depth, ret_code):
+    # TODO: validate cert format, dates, signature, etc.
     return True
 
 
     return True
 
 
@@ -34,87 +36,96 @@ class CBSException(Exception):
         self.logdata = logdata
 
 
         self.logdata = logdata
 
 
-def recv_req(conn: SSL.Connection, timeout=.5):
+def recv_request(conn: SSL.Connection, timeout=.5) -> bytes:
     data = b''
     start = time.time()
     while True:
     data = b''
     start = time.time()
     while True:
-        # This prevents "slow loris" types of timeouts
-        if time.time() > start + timeout:
-            raise CBSException(59, 'Timeout while waiting for URL')
+        if time.time() > start + timeout:  # Slow loris timeout
+            raise CBSException(59, 'Timeout while waiting for request')
         ready = select.select([conn], [], [], timeout)
         if ready[0]:
             data += conn.recv(1024)
             if b'\r\n' in data:
                 lines = data.splitlines()
         ready = select.select([conn], [], [], timeout)
         if ready[0]:
             data += conn.recv(1024)
             if b'\r\n' in data:
                 lines = data.splitlines()
-                if len(lines) > 1:
-                    logging.warning(f'Discarding data after URL line of request: {data}')
-                if len(lines[0]) > 1024:
-                    raise CBSException(59, 'URL too long', lines[0])
-                try:
-                    req = lines[0].decode('ascii')
-                except UnicodeDecodeError:
-                    raise CBSException(59, 'Non-ascii URL', lines[0])
-                return req
+                return lines[0]
         else:
         else:
-            raise CBSException(59, 'Timeout while waiting for URL')
+            raise CBSException(59, 'Timeout while waiting for request')
 
 
 
 
-def translate_path(url_path: str, base_path: str, check_existence=True, allow_extra=True):
-    # Build path one element at a time until we find a file
-    trans_path = base_path
-    path_len = 0
-    for part in url_path.split('/'):
-        path_len += len(part) + 1
-        # RFC 3986 says path components may have parameters, so look for any
-        # reserved delimiter characters and discard everything after one.
-        # Although the Gemini spec says not all of the components of generic URI
-        # syntax are supported, and disallowing path parameters seems in the
-        # spirit of the protocol, path parameters are not specifically mentioned
-        # so I try to do what feels safest and expect that they may show up.
-        part = unquote(re.split('[!$&\'()*+,;=]', part)[0])
-        trans_path = path.join(trans_path, part)
-        if check_existence and path.isfile(trans_path):
-            break
-    else:
-        if check_existence:
-            if path.isdir(trans_path):
-                trans_path = path.join(trans_path, 'index.gmi')
-                if not path.isfile(trans_path):
-                    raise CBSException(51, 'URL not found', trans_path)
-            else:
-                raise CBSException(51, 'URL not found', trans_path)
+def check_request(request: bytes) -> ParseResult:
+    # Gemini protocol specifies max 1024-byte URI
+    if len(request) > 1024:
+        raise CBSException(59, 'Request URI too long')
+
+    # The gemini protocol trades in UTF-8, but URIs can only contain ASCII
+    try:
+        request = request.decode('ascii')
+    except UnicodeDecodeError:
+        raise CBSException(59, 'Non-ASCII URI')
 
 
-    # Make sure the path didn't escape the base path.
-    trans_path = path.realpath(trans_path)
-    if path.commonpath([base_path, trans_path]) != base_path:
-        raise CBSException(59, 'Naughty directory traversal', trans_path)
+    # Parse URI and do some sanity checks
+    try:
+        parsed = urlsplit(request)  # May raise ValueError
+        uri_port = parsed.port  # Invalid port number raises ValueError on access
+    except ValueError:
+        raise CBSException(59, 'Invalid URI')
+    if parsed.scheme != 'gemini':
+        raise CBSException(59, 'Non-gemini scheme')
+    if parsed.username is not None:
+        raise CBSException(59, 'Username in URI disallowed')
+    if parsed.password is not None:
+        raise CBSException(59, 'Password in URI disallowed')
+    if parsed.fragment != '':
+        raise CBSException(59, 'Fragment in URI disallowed')
+    if any(delim in parsed.path for delim in ':?#[]@!$&\'(),;=*'):
+        raise CBSException(59, 'Invalid URI path')
+
+    return parsed
+
+
+def lookup_request(url_path: str, docroot: str):
+    # Build a resource path (and extra path for CGI)
+    translated = docroot
+    extra = ''
+    found = False
+    for part in url_path.split('/'):
+        unquoted = unquote(part)
+        if '/' in unquoted:  # Don't want to deal with escaped path delimiters
+            raise CBSException(59, 'Invalid URI path')
+        if not found:
+            translated = path.join(translated, unquoted)
+            if path.isfile(translated):
+                found = True
+        else:
+            extra += '/' + unquoted
+
+    # Look for an index if the path is a directory
+    if not found:
+        if path.isdir(translated):
+            translated = path.join(translated, 'index.gmi')
+            if not path.isfile(translated):
+                raise CBSException(51, 'URL not found')
+        else:
+            raise CBSException(51, 'URL not found')
 
 
-    # Grab all the leftovers verbatim for CGI scripts.
-    extra_path = url_path[max(path_len-1, 0):]
-    if extra_path and not allow_extra:
-        raise CBSException(59, 'Extra unexpected path information', extra_path)
+    # Make sure path doesn't escape the document root
+    abs_path = path.realpath(translated)
+    if path.commonpath([abs_path, docroot]) != docroot:
+        raise CBSException(59, 'Invalid URI path')
 
 
-    return trans_path, extra_path
+    return abs_path, translated, extra
 
 
 # ------------------------------------------------------------------------------
 # Serving
 
 
 
 
 # ------------------------------------------------------------------------------
 # Serving
 
 
-def serve_req(conn: SSL.Connection, addr, url: str, conf: dict):
-    # Attempt to parse the url and do basic validation
-    logging.info('Serving URL "{}"'.format(url))
-    try:
-        url_parsed = urlparse(url)
-    except ValueError:
-        raise CBSException(59, 'Could not parse URL', url)
-    if url_parsed.scheme != 'gemini':
-        raise CBSException(59, 'Non-gemini scheme', url_parsed.scheme)
-    if url_parsed.netloc == '':
-        raise CBSException(59, 'Netloc unspecified', url)
-
+def serve_req(conn: SSL.Connection, addr, url_parsed, conf: dict, absolute, relative, extra):
     # Parse the path information into a system path
     # Parse the path information into a system path
-    req_path, extra_path = translate_path(url_parsed.path, conf['servedir'])
+    url = urlunsplit(url_parsed)
+    req_path = absolute
+    extra_path = extra
+    logging.info('Serving URL "{}"'.format(url))
 
     # If the path is in the cgi directory then do some special CGI stuff.
     if conf['cgidir'] is not None and path.commonpath([conf['cgidir'], req_path]) == conf['cgidir']:
 
     # If the path is in the cgi directory then do some special CGI stuff.
     if conf['cgidir'] is not None and path.commonpath([conf['cgidir'], req_path]) == conf['cgidir']:
@@ -130,7 +141,8 @@ def serve_req(conn: SSL.Connection, addr, url: str, conf: dict):
 
 def serve_cgi(conn: SSL.Connection, addr, req_path, extra_path, url, conf: dict):
     cert = conn.get_peer_certificate()
 
 def serve_cgi(conn: SSL.Connection, addr, req_path, extra_path, url, conf: dict):
     cert = conn.get_peer_certificate()
-    extra_trans, _ = translate_path(extra_path, conf['servedir'], check_existence=False, allow_extra=False)
+    #extra_trans, _ = translate_path(extra_path, conf['servedir'], check_existence=False, allow_extra=False)
+    extra_trans = path.join(conf['servedir'], extra_path)
 
     # TODO: properly escape characters in DNs, see RFC 2253
     if cert is None:
 
     # TODO: properly escape characters in DNs, see RFC 2253
     if cert is None:
@@ -257,8 +269,10 @@ def main():
                 conn, addr = ssock.accept()
                 conn.do_handshake()
                 logging.info('Connection from {}'.format(addr))
                 conn, addr = ssock.accept()
                 conn.do_handshake()
                 logging.info('Connection from {}'.format(addr))
-                req = recv_req(conn)
-                serve_req(conn, addr, req, conf)
+                req = recv_request(conn)
+                url = check_request(req)
+                absolute, relative, extra = lookup_request(url.path, conf['servedir'])
+                serve_req(conn, addr, url, conf, absolute, relative, extra)
                 conn.shutdown()
                 conn.sock_shutdown(socket.SHUT_RDWR)
             except SSL.SysCallError as x:
                 conn.shutdown()
                 conn.sock_shutdown(socket.SHUT_RDWR)
             except SSL.SysCallError as x:
@@ -271,6 +285,7 @@ def main():
                 conn.sock_shutdown(socket.SHUT_RDWR)
             except Exception as x:
                 logging.error('Exception: {}'.format(x))
                 conn.sock_shutdown(socket.SHUT_RDWR)
             except Exception as x:
                 logging.error('Exception: {}'.format(x))
+                logging.error(traceback.format_exc())
                 conn.sendall('40 Server error\r\n'.encode('utf-8'))
                 conn.shutdown()
                 conn.sock_shutdown(socket.SHUT_RDWR)
                 conn.sendall('40 Server error\r\n'.encode('utf-8'))
                 conn.shutdown()
                 conn.sock_shutdown(socket.SHUT_RDWR)