]> git.the-white-hart.net Git - gemini/cbs-server.git/commitdiff
Refactor error handling
authorrs <>
Mon, 21 Feb 2022 01:57:53 +0000 (19:57 -0600)
committerrs <>
Mon, 21 Feb 2022 01:57:53 +0000 (19:57 -0600)
cbs-srv.py

index fc126c66fb2e84aeae46cd16d623ba202f1dffeb..d85dbce6cc21f0478a4b576c4aeb1d8f59f4425c 100755 (executable)
@@ -18,11 +18,18 @@ mimetypes.add_type('text/gemini', '.gmi')
 mimetypes.add_type('text/gemini', '.gemini')
 
 # ------------------------------------------------------------------------------
+# Helpers
 
 
-class CBSNotFound(Exception): pass
-class CBSTraversal(Exception): pass
-class CBSExtraPath(Exception): pass
+def accept_client_cert(conn, cert, err_num, err_depth, ret_code):
+    return True
+
+
+class CBSException(Exception):
+    def __init__(self, code, meta, logdata=None):
+        self.code = code
+        self.meta = meta
+        self.logdata = logdata
 
 
 def recv_req(conn: SSL.Connection, timeout=.1):
@@ -37,13 +44,11 @@ def recv_req(conn: SSL.Connection, timeout=.1):
                     logging.warning('Discarding data after URL line of request: {}'.format(data))
                 try:
                     req = lines[0].decode('ascii')
-                except Exception:
-                    logging.error('URL is not ascii: {}'.format(data))
-                    return None
+                except UnicodeDecodeError:
+                    raise CBSException(59, 'Non-ascii URL', data)
                 return req
         else:
-            logging.error('Timeout while waiting for URL')
-            return None
+            raise CBSException(59, 'Timeout while waiting for URL')
 
 
 def translate_path(url_path: str, base_path: str, check_existence=True, allow_extra=True):
@@ -67,24 +72,25 @@ def translate_path(url_path: str, base_path: str, check_existence=True, allow_ex
             if path.isdir(trans_path):
                 trans_path = path.join(trans_path, 'index.gmi')
                 if not path.isfile(trans_path):
-                    raise CBSNotFound(trans_path)
+                    raise CBSException(51, 'URL not found', trans_path)
             else:
-                raise CBSNotFound(trans_path)
+                raise CBSException(51, 'URL not found', trans_path)
 
     # Make sure the path didn't escape the base path.
     trans_path = path.realpath(trans_path)
     if path.commonpath([base_path, trans_path]) != base_path:
-        raise CBSTraversal(trans_path)
+        raise CBSException(59, 'Naughty directory traversal', trans_path)
 
     # Grab all the leftovers verbatim for CGI scripts.
     extra_path = url_path[max(path_len-1, 0):]
     if extra_path and not allow_extra:
-        raise CBSExtraPath(extra_path)
+        raise CBSException(59, 'Extra unexpected path information', extra_path)
 
     return trans_path, extra_path
 
 
 # ------------------------------------------------------------------------------
+# Serving
 
 
 def serve_req(conn: SSL.Connection, addr, url: str, conf: dict):
@@ -93,24 +99,14 @@ def serve_req(conn: SSL.Connection, addr, url: str, conf: dict):
     try:
         url_parsed = urlparse(url)
     except ValueError:
-        logging.error('Could not parse URL: "{}"'.format(url))
-        return serve_badreq(conn, "Could not parse URL")
+        raise CBSException(59, 'Could not parse URL', url)
     if url_parsed.scheme != 'gemini':
-        logging.error('Bad scheme: "{}"'.format(url_parsed.scheme))
-        return serve_badreq(conn, "Non-gemini scheme")
+        raise CBSException(59, 'Non-gemini scheme', url_parsed.scheme)
     if url_parsed.netloc == '':
-        logging.error('Netloc unspecified: "{}"'.format(url))
-        return serve_badreq(conn, "Netloc unspecified")
+        raise CBSException(59, 'Netloc unspecified', url)
 
     # Parse the path information into a system path
-    try:
-        req_path, extra_path = translate_path(url_parsed.path, conf['servedir'])
-    except CBSNotFound:
-        logging.error('URL not found: "{}"'.format(url))
-        return serve_notfound(conn)
-    except CBSTraversal:
-        logging.error('URL contains bad traversal: "{}"'.format(url))
-        return serve_badreq(conn, "Naughty directory traversal")
+    req_path, extra_path = translate_path(url_parsed.path, conf['servedir'])
 
     # If the path is in the cgi directory then do some special CGI stuff.
     if conf['cgidir'] is not None and path.commonpath([conf['cgidir'], req_path]) == conf['cgidir']:
@@ -118,34 +114,15 @@ def serve_req(conn: SSL.Connection, addr, url: str, conf: dict):
 
     # If the request is for a static file, there should be no extra path info
     if extra_path:
-        logging.warning('Extra path info after file: "{}"'.format(url_parsed.path))
-        return serve_notfound(conn)
+        raise CBSException(51, 'URL not found', 'extra path info: {}'.format(url_parsed.path))
 
     # Otherwise, serve up a static file
     return serve_file(conn, req_path)
 
 
-def serve_badreq(conn: SSL.Connection, msg=''):
-    conn.send('59 {}\r\n'.format(msg).encode('utf-8'))
-
-
-def serve_notfound(conn: SSL.Connection):
-    conn.send('51 Page not found\r\n'.encode('utf-8'))
-
-
-def serve_cgierror(conn: SSL.Connection, msg=''):
-    conn.send('42 {}\r\n'.format(msg).encode('utf-8'))
-
-
 def serve_cgi(conn: SSL.Connection, addr, req_path, extra_path, url, conf: dict):
     cert = conn.get_peer_certificate()
-
-    try:
-        extra_trans, _ = translate_path(extra_path, conf['servedir'], check_existence=False, allow_extra=False)
-    except CBSTraversal:
-        logging.error('Extra path contains bad traversal: "{}"'.format(extra_path))
-        return serve_badreq(conn, "Naughty directory traversal")
-
+    extra_trans, _ = translate_path(extra_path, conf['servedir'], check_existence=False, allow_extra=False)
     env = environ.copy()
 
     # RFC 3875
@@ -182,30 +159,28 @@ def serve_cgi(conn: SSL.Connection, addr, req_path, extra_path, url, conf: dict)
     try:
         proc = subprocess.run(req_path, env=env, timeout=10, capture_output=True, check=True)
     except subprocess.TimeoutExpired:
-        logging.error('CGI script timeout: "{}"'.format(req_path))
-        return serve_cgierror(conn, "CGI script timeout")
+        raise CBSException(42, 'CGI script timeout', req_path)
     except subprocess.CalledProcessError as x:
-        logging.error('CGI script returned error: "{}" -> {}'.format(req_path, x.returncode))
-        return serve_cgierror(conn, "CGI script returned error")
+        raise CBSException(42, 'CGI script error', '{} -> {}'.format(req_path, x.returncode))
     except PermissionError:
-        logging.error('CGI script permission error: "{}"'.format(req_path))
-        return serve_cgierror(conn, "CGI not executable")
-
+        raise CBSException(42, 'CGI not executable', req_path)
     conn.send(proc.stdout)
 
 
 def serve_file(conn: SSL.Connection, filedir):
     mime_type, encoding = mimetypes.guess_type(filedir)
-    with open(filedir, 'rb') as f:
-        conn.send('20 {}\r\n'.format(mime_type or 'application/octet-stream').encode('utf-8'))
-        conn.send(f.read())
+    try:
+        f = open(filedir, 'rb')
+        content = f.read()
+        f.close()
+    except Exception as x:
+        raise CBSException(40, 'Server error accessing content', x)
+    conn.send('20 {}\r\n'.format(mime_type or 'application/octet-stream').encode('utf-8'))
+    conn.send(content)
 
 
 # ------------------------------------------------------------------------------
-
-
-def accept_client_cert(conn, cert, err_num, err_depth, ret_code):
-    return True
+# Top level
 
 
 def main():
@@ -235,11 +210,15 @@ def main():
             conn, addr = ssock.accept()
             conn.do_handshake()
             logging.info('Connection from {}'.format(addr))
-            req = recv_req(conn)
-            if req is not None:
+            try:
+                req = recv_req(conn)
                 serve_req(conn, addr, req, conf)
-            else:
-                serve_badreq(conn, "Received invalid request")
+            except CBSException as x:
+                logging.error('{} {} {}'.format(x.code, x.meta, x.logdata))
+                conn.send('{} {}\r\n'.format(x.code, x.meta).encode('utf-8'))
+            except Exception as x:
+                logging.error('Exception: {}'.format(x))
+                conn.send('40 Server error\r\n')
             conn.shutdown()
             conn.sock_shutdown(socket.SHUT_RDWR)