diff --git a/.gitignore b/.gitignore index 3161a96..b1c88cd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class -_test_sqlitedb/*.db +test_sqlitedb/*.db csv_file/*.csv test.py test_args.py diff --git a/README.md b/README.md index 95c532d..d8c6797 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## Synopsis -This is a universal SQL client that is written in Python. It is a command prompt developed due to nuisance experience that author faces when he uses a traditional SQL client. This is also a very lightweight SQL client that tries its best to reduce the dependacy of the external libraries apart from the SQL connectors and webapp library. The author hopes that with the minimal coding effort, the tool can be fitted into any platform that installs Python, including modern embedded devices. +This is a universal SQL client that is written in Python. It is a command prompt developed due to nuisance experience that author faces when he used to use a traditional SQL client. This is also a very lightweight SQL client that tries its best to reduce the dependacy of the external libraries apart from the SQL connectors and webapp library. The author hopes that with the minimal coding effort, the tool can be fitted into any platform that installs Python, including modern embedded devices. ## Addressing Controversial Remarks @@ -57,14 +57,12 @@ This is a universal SQL client that is written in Python. It is a command prompt My tool: By typing ```webapp``` on the command prompt: - + ```unix - 1> webapp + 1> webapp ``` - A webapp will be spawn on http://127.0.0.1:5000/sql_webapp - - + A webapp will be spawn on http://127.0.0.1:5000/sql_web 4. Save the query to csv file @@ -183,7 +181,6 @@ The ```init.py``` script will write to config file based on your input ```unix prompt> select * from yahoo_123; - ``` ## Tested Database Connector @@ -192,6 +189,24 @@ The ```init.py``` script will write to config file based on your input 2. [MySQL](https://github.com/PyMySQL/mysqlclient-python) 3. [MsSQL](https://github.com/pymssql/pymssql) +## Interesting Stuff + +To experienbce the power of Falcon loaded with the shield of Vue + +1. On the prompt, just enter your usual SQL command: + + ```unix + prompt> falcon + ``` + +2. Go to the your browser and enter ```127.0.0.1:8041/sql_webapp_vue``` + +3. Enter the usual SQL command + + ```sql + select * from yahoo_123 + ``` + ## Things to Do 1. Write the test case (In Progress) @@ -211,4 +226,8 @@ The ```init.py``` script will write to config file based on your input ## History to remember (14/4/2019) -Facebook, Instagram and WhatsApp are currently down upon my submission of this repo. \ No newline at end of file +Facebook, Instagram and WhatsApp are currently down upon my submission of this repo. + +## My Talk at PyCon TW 2019 + +Here is the link to [my talk at PyCon TW 2019](https://www.youtube.com/watch?v=3tP3UHfv9rs) diff --git a/app/static/loading_2.gif b/app/static/loading_2.gif new file mode 100644 index 0000000..631bae9 Binary files /dev/null and b/app/static/loading_2.gif differ diff --git a/app/templates/webapp_vue.html b/app/templates/webapp_vue.html new file mode 100644 index 0000000..8d6f468 --- /dev/null +++ b/app/templates/webapp_vue.html @@ -0,0 +1,212 @@ + + + + + + + + + + + + + + + My Universal SQL Client WebApp + + + +

My Universal SQL Client WebApp

+

The current DB is: {{ db_name }}

+
+

Please enter your command below

+ +
+
+ + +
+
+ + + +
+
+
+
+
+
+ + + \ No newline at end of file diff --git a/app/templates/webdata.html b/app/templates/webdata.html index 15c0794..50c01e9 100644 --- a/app/templates/webdata.html +++ b/app/templates/webdata.html @@ -5,8 +5,8 @@ + + +

My Universal SQL Client WebApp

+

The current DB is: {{ db_name }}

+
+

Title: + +

+

X-Axis Unit (if any): + +

+

y-Axis Unit (if any): + +

+

Chart Type: + +

+

Please enter your command below

+ +
+
+ +
+
+ + + +
+
+
+
+
+ + + \ No newline at end of file diff --git a/app/templates/webvisual_vue_chartjs.html b/app/templates/webvisual_vue_chartjs.html new file mode 100644 index 0000000..a21d263 --- /dev/null +++ b/app/templates/webvisual_vue_chartjs.html @@ -0,0 +1,228 @@ + + + + + + + + My Universal SQL Client WebApp + + + +

My Universal SQL Client WebApp

+

The current DB is: {{ db_name }}

+
+

Title: + +

+

Data Label (if any): + +

+

DataSet Unit (if any): + +

+

X-Axis Unit (if any): + +

+

y-Axis Unit (if any): + +

+

Chart Type: + +

+

Please enter your command below

+ +
+
+ +
+
+ + + +
+
+
+
+
+ + + \ No newline at end of file diff --git a/app/utils.py b/app/utils.py new file mode 100644 index 0000000..9f01008 --- /dev/null +++ b/app/utils.py @@ -0,0 +1,22 @@ +import os +from interface.control import DBControlInterface + + +def get_data_from_db(db_client: DBControlInterface, incoming_data: str) -> (list, list): + + if os.environ["DB_TYPE"] == "sqlite3": + + sqlite3_client = DBControlInterface( + "sqlite3", db_nickname=os.environ["DB_NAME"] + ) + + sqlite3_client.connect() + return_data = sqlite3_client.command_interface("""{}""".format(incoming_data)) + + return [x[0] for x in sqlite3_client.cursor.description], return_data + + return_data = db_client.command_interface( + """{}""".format(incoming_data) + ) + + return [x[0] for x in db_client.cursor.description], return_data diff --git a/app/webapp.py b/app/webapp.py index c0e4527..8253ad6 100644 --- a/app/webapp.py +++ b/app/webapp.py @@ -7,6 +7,7 @@ #from database.models import DataBaseEngine from interface.control import DBControlInterface from chart.chart import generate_chart +from app.utils import get_data_from_db db_client = None @@ -25,31 +26,44 @@ app = Flask(__name__) -def _get_data_from_db(incoming_data: str) -> (list, list): +@app.route("/debug") +def get_debug(): - if os.environ["DB_TYPE"] == "sqlite3": + return "Webapp is running" - sqlite3_client = DBControlInterface( - "sqlite3", db_nickname=os.environ["DB_NAME"] - ) - - sqlite3_client.connect() - return_data = sqlite3_client.command_interface("""{}""".format(incoming_data)) +@app.route("/sql_api", methods=["POST", "GET"]) +def api_post_command(): - return [x[0] for x in sqlite3_client.cursor.description], return_data + if request.method == "POST": - return_data = db_client.command_interface( - """{}""".format(incoming_data) - ) + incoming_data = request.get_json() + incoming_sql_statement = incoming_data["sql_command"] - return [x[0] for x in db_client.cursor.description], return_data + table_header, db_response = get_data_from_db(db_client, incoming_sql_statement) + db_response_list = [ + dict(zip(table_header, db_data)) for db_data in db_response + ] -@app.route("/debug") -def get_debug(): + if ("item_per_page" in incoming_data.keys()) and ("current_index" in incoming_data.keys()): - return "Webapp is running" + current_index = incoming_data["current_index"] + item_per_page = incoming_data["item_per_page"] + db_response_list = db_response_list[current_index: current_index + item_per_page] + return jsonify( + { + "status": 200, + "sql_response": db_response_list + } + ) + + return jsonify( + { + "status": 403, + "message": "Invalid request" + } + ) @app.route("/sql_webapp", methods=["POST", "GET"]) def post_command(): @@ -57,7 +71,7 @@ def post_command(): if request.method == "POST": incoming_data = request.form.get("textbox") - table_header, db_response = _get_data_from_db(incoming_data) + table_header, db_response = get_data_from_db(db_client, incoming_data) return render_template( "response.html", @@ -75,7 +89,7 @@ def generate_chart_js(): incoming_data = request.form - _, db_response = _get_data_from_db(incoming_data["sql_query"]) + _, db_response = get_data_from_db(db_client, incoming_data["sql_command"]) chart_data = generate_chart( incoming_data["chart"], diff --git a/app/webapp_falcon.py b/app/webapp_falcon.py new file mode 100644 index 0000000..5231ffe --- /dev/null +++ b/app/webapp_falcon.py @@ -0,0 +1,201 @@ +import json +import re +import os +import time +import falcon +import jinja2 +#from database.models import DataBaseEngine +import datetime +import decimal +from interface.control import DBControlInterface +from app.utils import get_data_from_db + + +db_client = None +api_storage = None +table_key = None +previous_command = None +current_index = 0 + + +if os.environ["DB_TYPE"] != "sqlite3": + + db_client = DBControlInterface( + os.environ["DB_TYPE"], + db_nickname=os.environ["DB_NAME"] + ) + + db_client.connect() + + +def _json_serializer(obj): + + if isinstance(obj, datetime.datetime): + + return str(obj) + + elif isinstance(obj, decimal.Decimal): + + return str(obj) + + raise TypeError('Cannot serialize {!r} (type {})'.format(obj, type(obj))) + + +def _return_query_result(sql_cursor): + + return ( + sql_cursor.cursor.fetchall(), + [x[0] for x in sql_cursor.cursor.description] + ) + + +class WebDebug: + + def on_get(self, request, response): + + response.content_type = falcon.MEDIA_TEXT + response.body = "Webapp is running" + + +class WebSQLAPI: + + def on_get(self, request, response): + + response.content_type = falcon.MEDIA_JSON + + response.media = { + "status": falcon.HTTP_500, + "message": "Invalid Request" + } + + def on_post(self, request, response): + + data_from_request = request.media + + try: + + if "current_index" in data_from_request and "item_per_page" in data_from_request: + + global api_storage + global table_key + global previous_command + + if not api_storage or previous_command != data_from_request["sql_command"]: + + table_key, response_db = get_data_from_db( + db_client, data_from_request["sql_command"] + ) + + previous_command = data_from_request["sql_command"] + + api_storage = [ + { table_key[i]: str(r) if (isinstance(r, datetime.datetime) or isinstance(r, decimal.Decimal)) else r + for i, r in enumerate(response)} + for response in response_db + ] + + current_index = data_from_request["current_index"] + item_per_page = data_from_request["item_per_page"] + + response.media = { + "status": 200, + "sql_header": table_key, + "sql_response": api_storage[current_index: current_index + item_per_page], + "sql_length": len(api_storage) + } + + else: + + table_key, response_db = get_data_from_db(db_client, data_from_request["sql_command"]) + + response.media = { + "status": 200, + "sql_header": table_key, + "sql_response": [ + [ + str(r) if (isinstance(r, datetime.datetime) or isinstance(r, decimal.Decimal)) else r for r in response + ] for response in response_db + ] + } + + except Exception as e: + + response.media = { + "status": 500, + "error_msg": str(e) + } + +class WebSQLChartjsVue: + + def on_get(self, request, response): + + response.content_type = falcon.MEDIA_JSON + + response.media = { + "status": falcon.HTTP_500, + "message": "Invalid Request" + } + + def on_post(self, request, response): + + pass + """incoming_data = request.media + + _, db_response = get_data_from_db(db_client, incoming_data["sql_statement"]) + + chart_data = generate_chart( + incoming_data["chart"], + incoming_data["title"], + incoming_data["dataset_label"], + db_response, + ) + + #print(json.dumps(chart_data, indent=4)) + return render_template( + "webvisual.html", + height=incoming_data["height"], + unit=incoming_data["dataset_unit"], + incoming_data_json=json.dumps(chart_data) + ) + """ + +class WebSQLFrontVue: + + def on_get(self, request, response): + + loader = jinja2.FileSystemLoader("app/templates") + load_template_file = jinja2.Environment(loader=loader).get_template("webapp_vue.html") + + response.content_type = falcon.MEDIA_HTML + response.body = load_template_file.render(db_name=os.environ["DB_NAME"]) + + +class WebSQLStatic: + + def on_get(self, request, response, filename): + + response.status = falcon.HTTP_200 + response.content_type = 'appropriate/content-type' + + with open("app/static/{}".format(filename), 'rb') as f: + + response.body = f.read() + + +class WebSQLVisualVue: + + def on_get(self, request, response): + + loader = jinja2.FileSystemLoader("app/templates") + load_template_file = jinja2.Environment(loader=loader).get_template("webvisual_vue.html") + + response.content_type = falcon.MEDIA_HTML + response.body = load_template_file.render(db_name=os.environ["DB_NAME"]) + + +app = falcon.API() +app.add_route("/debug", WebDebug()) +app.add_route("/sql_api", WebSQLAPI()) +app.add_route("/sql_webapp_vue", WebSQLFrontVue()) +app.add_route("/static/{filename}", WebSQLStatic()) +app.add_route("/sql_visual_vue", WebSQLVisualVue()) \ No newline at end of file diff --git a/chart/colours.py b/chart/colours.py index 19304fb..5903301 100644 --- a/chart/colours.py +++ b/chart/colours.py @@ -1,7 +1,7 @@ import re -def generate_colours(colour_name: str=None, rgb_tuple: tuple=(), alpha: int=1.0): +def generate_colours(colour_name: str=None, rgb_code: tuple=(), alpha: int=1.0): defined_color_dict = { "black": (0,0,0,alpha), "white": (255,255,255,alpha), @@ -135,10 +135,10 @@ def generate_colours(colour_name: str=None, rgb_tuple: tuple=(), alpha: int=1.0) ) ) - if len(rgb_tuple) == 3: + if len(rgb_code) == 3: return 'rgba({0},{1})'.format( - ','.join([str(x) for x in rgb_tuple]), alpha + ','.join([str(x) for x in rgb_code]), alpha ) return defined_color_dict \ No newline at end of file diff --git a/config/command_def.json b/config/command_def.json index 8a98034..53237d1 100644 --- a/config/command_def.json +++ b/config/command_def.json @@ -1,10 +1,22 @@ { - "nas": { - "sql": "select * from nas;", - "parameters": false + "radius": { + "nas": { + "sql": "select * from nas;", + "parameters": false + }, + "demo": { + "sql": "select * from {};", + "parameters": true + } }, - "demo": { - "sql": "select * from {};", - "parameters": true + "city_line": { + "city": { + "sql": "select * from cities", + "parameters": false + }, + "any_table": { + "sql": "select * from {}", + "parameters": true + } } } \ No newline at end of file diff --git a/config/config.json b/config/config.json index edfa89a..d23d710 100644 --- a/config/config.json +++ b/config/config.json @@ -13,5 +13,14 @@ }, "sqlite3_person": { "db_name": "test_sqlitedb/test.db" + }, + "city_line":{ + "db_name": "test_sqlitedb/cities_lines.db" + }, + "library": { + "server": "localhost", + "username": "postgres", + "password": "postgres12345", + "db_name": "BookLibrary" } } \ No newline at end of file diff --git a/database/models.py b/database/models.py index 8be5942..0490c78 100644 --- a/database/models.py +++ b/database/models.py @@ -25,15 +25,19 @@ def __init__(self, db_engine: str, db_nickname=None): self.conn = None self.cursor = None self.db_engine = db_engine - self.db_nickname = db_nickname + self.db_nickname = db_nickname #self.sqlite3_filename = sqlite3_filename def connect(self): + if self.db_nickname not in self.configuration: + + print(f"Table {self.db_nickname} Not Found in config.json") + db_module = import_module(self.db_engine) try: - + if self.db_engine == "sqlite3": self.conn = db_module.connect( @@ -43,10 +47,10 @@ def connect(self): else: self.conn = db_module.connect( - self.configuration[self.db_nickname]["server"], - self.configuration[self.db_nickname]["username"], - self.configuration[self.db_nickname]["password"], - self.configuration[self.db_nickname]["db_name"] + host=self.configuration[self.db_nickname]["server"], + database=self.configuration[self.db_nickname]["db_name"], + user=self.configuration[self.db_nickname]["username"], + password= self.configuration[self.db_nickname]["password"], ) except db_module.InterfaceError: @@ -69,13 +73,15 @@ def execute(self, sql_statement:str): self.cursor.execute(sql_statement) - if re.search(r"(?i)(CREATE|INSERT|UPDATE|DELETE|ALTER|DELELE).+", sql_statement): + sql_regex = re.compile(r"^(?i)(CREATE|ALTER|UPDATE|INSERT|DELETE|VIEW|ROLLBACK)$") + + if sql_regex.match(sql_statement.split()[0]): self.conn.commit() def rollback(self): - self.cursor.rollback() + self.conn.rollback() def retrieve_table(self): diff --git a/debug.txt b/debug.txt new file mode 100644 index 0000000..e69de29 diff --git a/interface/control.py b/interface/control.py index 0c1e56b..3f8e875 100644 --- a/interface/control.py +++ b/interface/control.py @@ -2,11 +2,12 @@ import json import os import re -from database.models import DataBaseEngine from typing import List +import waitress +from database.models import DataBaseEngine -def parse_lambda(key, value, func): +def parse_lambda(value: dict, func): if value["parameters"]: @@ -14,7 +15,7 @@ def parse_lambda(key, value, func): """{}""".format(value["sql"]).format(x) ) - return lambda : func("""{}""".format(value["sql"])) + return lambda: func("""{}""".format(value["sql"])) class DBControlInterface(DataBaseEngine): @@ -23,22 +24,26 @@ class DBControlInterface(DataBaseEngine): defined_attributes = None if os.path.exists("config/command_def.json"): - + with open("config/command_def.json", "r") as def_file: defined_attributes = json.loads(def_file.read()) def __init__(self, db_engine: str, db_nickname=None): - for key, value in self.defined_attributes.items(): - - setattr( - self, - "get_{}".format(key), - parse_lambda(key, value, self.get_sql) - ) - - super().__init__(db_engine, db_nickname=db_nickname) + self.db_nickname = db_nickname + + if self.db_nickname and self.db_nickname in self.defined_attributes.keys(): + + for key, value in self.defined_attributes[self.db_nickname].items(): + + setattr( + self, + "get_{}".format(key), + parse_lambda(value, self.get_sql) + ) + + super().__init__(db_engine, db_nickname=self.db_nickname) def _write_to_file(self, filename: str, query_result: List): @@ -49,13 +54,13 @@ def _write_to_file(self, filename: str, query_result: List): if self.cursor.description: csv_writer.writerow([desc[0] for desc in self.cursor.description]) - + csv_writer.writerows(query_result) return "Query result written to file: {0}".format(filename) def _save_query_to_file(self, filename: str, sql_command: str=None): - + query_result = None command_buffer = None @@ -74,13 +79,13 @@ def _save_query_to_file(self, filename: str, sql_command: str=None): elif sql_command and re.search(r"(?i)(select)\s.+", sql_command): query_result = self.command_interface(sql_command) - + self._write_to_file(filename, query_result) def get_db(self): return list(self.configuration.keys()) - + def get_help(self): return """ @@ -122,13 +127,13 @@ def get_r(self): return "No Previous Command Found!" return "Your Previous Command is: {}".format(self.command_stored_in_buffer) - + def get_t(self): if not self.command_stored_in_buffer: return "You cannot execute any command" - + else: buffer_list = self.command_stored_in_buffer.split(" ") @@ -143,31 +148,34 @@ def get_t(self): def get_save(self, args_list: list): if os.path.exists(args_list[0]): - + return "You Have saved your previous query to file: {0}".format(args_list[0]) - - else: - if len(args_list) > 1: + if len(args_list) > 1: - self._save_query_to_file( - args_list[-1], #filename - sql_command= " ".join(args_list[0:-1]) #sqlcommand - ) + self._save_query_to_file( + args_list[-1], #filename + sql_command=" ".join(args_list[0:-1]) #sqlcommand + ) - elif len(args_list) == 1: - - self._save_query_to_file( - args_list[0], #filename - self.command_stored_in_buffer - ) + elif len(args_list) == 1: + + self._save_query_to_file( + args_list[0], #filename + self.command_stored_in_buffer + ) - return "File {} written successfully".format(args_list[0]) + return "File {} written successfully".format(args_list[0]) def get_column(self, table_name: str): return self.retrieve_column_name(table_name[0]) + def get_rollback(self): + + self.rollback() + return "Rollback is triggered!" + def get_webapp(self): from app.webapp import app @@ -185,35 +193,37 @@ def get_sql(self, input_command): return "Not a valid SQL Expression!" - #def command_interface(self, input_command: str, command_stored_in_buffer: str): + def get_falcon(self): + + from app.webapp_falcon import app + + waitress.serve(app, host='127.0.0.1', port=8041, url_scheme='https') + def command_interface(self, input_command: str): input_command_list = input_command.split(" ") if hasattr(self, "get_{}".format(input_command_list[0])): - instance_method = getattr(self, - "get_{}".format(input_command_list[0]) - ) + instance_method = getattr(self, + "get_{}".format(input_command_list[0]) + ) if len(input_command_list) == 1: - + return instance_method() - elif len(input_command_list) > 1: - + if len(input_command_list) > 1: + if input_command_list[0] in ["save", "column"]: return instance_method(input_command_list[1:]) - elif input_command_list[0] == "table": - - return instance_method() + if input_command_list[0] in [ + key for key in self.defined_attributes[self.db_nickname].keys() + if dict(self.defined_attributes[self.db_nickname])[key]["parameters"] + ]: - elif input_command_list[0] in [ - key for key in self.defined_attributes.keys() - if dict(self.defined_attributes)[key]["parameters"]]: - return instance_method(input_command_list[1]) return self.get_sql(input_command) @@ -221,16 +231,14 @@ def command_interface(self, input_command: str): class DBInterface(DBControlInterface): - def _delay_mode(self, data_from_table: List, line_to_display=20): + def _delay_mode(self, data_from_table: List, display_line=20): - data_length = len(data_from_table) - - if data_length < line_to_display: + if len(data_from_table) < display_line: for data in data_from_table: print("|".join([str(item) for item in data])) - + else: index = 0 @@ -240,9 +248,9 @@ def _delay_mode(self, data_from_table: List, line_to_display=20): if next_input == ">": - if index < int(len(data_from_table)/line_to_display) + 1: + if index < int(len(data_from_table)/display_line) + 1: - init_index, final_index = index*line_to_display, (index + 1)*line_to_display + init_index, final_index = index*display_line, (index + 1)*display_line for query in data_from_table[init_index: final_index]: @@ -260,14 +268,14 @@ def _delay_mode(self, data_from_table: List, line_to_display=20): index -= 1 - for query in data_from_table[index*line_to_display: (index + 1)*line_to_display]: + for query in data_from_table[index*display_line: (index + 1)*display_line]: print("|".join([str(item) for item in query])) else: print("You are at the top of the page") - + elif next_input == "x": break @@ -280,39 +288,38 @@ def _delay_mode(self, data_from_table: List, line_to_display=20): def _get_output(self, data_from_db=None, delay_mode: bool=False, line_to_display: int=10): - #print("\n") if isinstance(data_from_db, str): - if len(data_from_db) > 0: + if data_from_db: print("{}\n".format(data_from_db)) - - elif isinstance(data_from_db, tuple) or isinstance(data_from_db, list): - print("\n") + elif isinstance(data_from_db, (tuple, list)): + table_header = "|".join([desc[0] for desc in self.cursor.description]) + print("\n") print(table_header) print("_" * len(table_header)) if delay_mode: self._delay_mode(data_from_db, line_to_display) - + else: for query in data_from_db: print("|".join([str(entry) for entry in query])) - + print("\n") - + def get_result_output(self, input_command: str, command_stored_in_buffer): self.command_stored_in_buffer = command_stored_in_buffer - if len(input_command) == 0: + if not input_command: pass @@ -320,18 +327,16 @@ def get_result_output(self, input_command: str, command_stored_in_buffer): self._get_output( data_from_db=self.command_interface( - input_command.split("|")[0], + input_command.split("|")[0], ), - delay_mode=True, + delay_mode=True, line_to_display=10 ) - + else: - + self._get_output( data_from_db=self.command_interface( input_command ) ) - - \ No newline at end of file diff --git a/interface/view.py b/interface/view.py index aaac700..d9bfde9 100644 --- a/interface/view.py +++ b/interface/view.py @@ -10,7 +10,6 @@ def front_prompt(database_type: str=None, database_nickname: str=None): - def platform_prompt(platform_name: str) -> str: platform_prompt_str = None @@ -32,17 +31,14 @@ def platform_prompt(platform_name: str) -> str: return platform_prompt_str - default_prompt_str = "{0} ".format( - str(time.ctime(time.time())), - ) + default_prompt_str = f"{str(time.ctime(time.time()))} " if database_type: if database_nickname: - default_prompt_str += "{0} {1} {2}> ".format( - database_type, + default_prompt_str += "{0} {1} > ".format( database_nickname, platform_prompt(sys.platform) ) @@ -84,20 +80,25 @@ def database_interface(database_type: str, db_nickname: str=None): received_command = front_prompt(database_type) - os.environ["DB_TYPE"] = database_type - - db_client = DBInterface(database_type, db_nickname=db_nickname) os.environ["DB_NAME"] = db_nickname + os.environ["DB_TYPE"] = database_type + db_client = DBInterface(database_type, db_nickname=db_nickname) db_client.connect() - + while (received_command not in exit_command_list): - + try: - if re.search(r"switch \w+$", received_command): + if re.search(r"switch \w+", received_command): + + received_command_list = received_command.split(" ") + + if len(received_command_list) == 3 and (database_type != received_command_list[1]): - database_interface(database_type, received_command.split(" ")[1]) + database_type = received_command_list[1] + + database_interface(database_type, db_nickname=received_command_list[-1]) else: @@ -111,7 +112,7 @@ def database_interface(database_type: str, db_nickname: str=None): if not (received_command in command_keys): command_buffer = received_command - + received_command = front_prompt(database_type, db_nickname) @@ -123,7 +124,7 @@ def initial_interface() -> str: if re.search(r"^connect\s([\w\s])+\s?([\w\s])*$", command): - return command.split(" ")[1:] + return command.split(" ")[1] else: diff --git a/main.py b/main.py index 244bb3f..8c5d8fb 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import os import re import sys +import waitress from interface.view import initial_interface from interface.view import database_interface from tools.init import init_db @@ -13,11 +14,12 @@ def db_main(sql_lib_definition_dict: dict): print("\nWelcome to the Universal SQL Client:\n") - if len(sys.argv) < 3: + if len(sys.argv) == 2: - database_type, database_nickname = initial_interface() + database_type = sql_lib_definition_dict[sys.argv[1]] + database_nickname = initial_interface() - else: + elif len(sys.argv) > 2: database_type = sql_lib_definition_dict[sys.argv[1]] database_nickname = sys.argv[2] @@ -51,14 +53,23 @@ def main(): elif len(sys.argv) > 3 and len(sys.argv) <= 4: - if sys.argv[1] == "webapp": + os.environ["DB_TYPE"] = sql_lib_definition_dict[sys.argv[2]] + os.environ["DB_NAME"] = sys.argv[3] - os.environ["DB_TYPE"] = sql_lib_definition_dict[sys.argv[2]] - os.environ["DB_NAME"] = sys.argv[3] + if sys.argv[1] == "webapp": from app.webapp import app app.run(host="127.0.0.1", port=5000) + + elif sys.argv[1] == "falcon": + + from app.webapp_falcon import app + + waitress.serve(app, host='127.0.0.1', port=8041, url_scheme='https') + + del os.environ["DB_TYPE"] + del os.environ["DB_NAME"] else: diff --git a/pycon_demo.txt b/pycon_demo.txt new file mode 100644 index 0000000..3bca11b --- /dev/null +++ b/pycon_demo.txt @@ -0,0 +1,3 @@ +select cities.name, cast(tracks.length as decimal) +as track_length from cities join tracks on cities.id = tracks.city_id +group by track_length having track_length > 50000 \ No newline at end of file diff --git a/radius b/radius new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index eec9ae8..3093f31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ flask pymssql -mysqlclient \ No newline at end of file +mysqlclient +falcon +waitress \ No newline at end of file diff --git a/run_test.py b/run_test.py new file mode 100644 index 0000000..8082028 --- /dev/null +++ b/run_test.py @@ -0,0 +1,8 @@ +import unittest + + +loader = unittest.TestLoader() +suite = loader.discover("tests") + +runner = unittest.TextTestRunner() +runner.run(suite) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_colors.py b/tests/test_colors.py new file mode 100644 index 0000000..4dd8c1a --- /dev/null +++ b/tests/test_colors.py @@ -0,0 +1,55 @@ +import unittest +from chart.colours import generate_colours + + +class TestColour(unittest.TestCase): + + def setUp(self): + + self.alpha = 0.5 + self.colours = { + "maroon": (128,0,0,self.alpha), "olive": (128,128,0,self.alpha), + "green": (0,128,0,self.alpha), "purple": (128,0,128,self.alpha), + "teal": (0,128,128,self.alpha), "navy": (0,0,128,self.alpha), + "dark red": (139,0,0,self.alpha), "brown": (165,42,42,self.alpha), + "firebrick": (178,34,34,self.alpha), "crimson": (220,20,60,self.alpha), + "tomato": (255,99,71,self.alpha), "coral": (255,127,80,self.alpha), + "indian_red": (205,92,92,self.alpha), "light_coral": (240,128,128,self.alpha), + "dark_salmon": (233,150,122,self.alpha), "salmon": (250,128,114,self.alpha), + "light_salmon": (255,160,122,self.alpha), "orange_red": (255,69,0,self.alpha), + "dark_orange": (255,140,0,self.alpha), "orange": (255,165,0,self.alpha), + "gold": (255,215,0,self.alpha), "dark_golden_rod": (184,134,11,self.alpha), + "golden_rod": (218,165,32,self.alpha), "pale_golden_rod": (238,232,170,self.alpha), + "dark_khaki": (189,183,107,self.alpha), "khaki": (240,230,140,self.alpha), + "yellow_green": (154,205,50,self.alpha), "dark_olive_green": (85,107,47,self.alpha), + "olive_drab": (107,142,35,self.alpha), "lawn_green": (124,252,0,self.alpha), + "chart_reuse": (127,255,0,self.alpha), "green_yellow": (173,255,47,self.alpha), + "dark_green": (0,100,0,self.alpha), "forest_green": (34,139,34,self.alpha), + "lime_green": (50,205,50,self.alpha) + } + + def test_color(self): + + for colour, colour_tuple in self.colours.items(): + + colour_tuple_str = "rgba({})".format( + ",".join([str(x) for x in colour_tuple]) + ) + + self.assertEqual( + generate_colours(colour, alpha=self.alpha), + colour_tuple_str + ) + + def test_rgb_code(self): + + for _, colour_tuple in self.colours.items(): + + colour_tuple_str = "rgba({})".format( + ",".join([str(x) for x in colour_tuple]) + ) + + self.assertEqual( + generate_colours(rgb_code=colour_tuple[:3], alpha=self.alpha), + colour_tuple_str + ) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..7793dd0 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,55 @@ +import unittest +from database.models import DataBaseEngine + + +class TestModelSqlite3(unittest.TestCase): + + def setUp(self): + + self.db_engine = DataBaseEngine("sqlite3", "sqlite3_person") + self.db_engine.connect() + + def test_table(self): + + self.assertIsInstance( + self.db_engine.retrieve_table(), + list + ) + + def test_sql(self): + + self.db_engine.execute("select * from Person") + self.assertIsInstance( + self.db_engine.cursor.fetchall(), + list + ) + + +class TestModelMySQL(unittest.TestCase): + + def setUp(self): + + self.db_engine = DataBaseEngine("MySQLdb", "radius") + self.db_engine.connect() + + def test_sql(self): + + self.db_engine.execute("select * from radacct") + self.assertIsInstance( + self.db_engine.cursor.fetchall(), + tuple + ) + + def test_column(self): + + self.assertIsInstance( + self.db_engine.retrieve_column_name("radacct"), + tuple + ) + + def test_table(self): + + self.assertIsInstance( + self.db_engine.retrieve_table(), + tuple + ) diff --git a/test_control.py b/tests/test_mysql_control.py similarity index 94% rename from test_control.py rename to tests/test_mysql_control.py index 6e44038..696eb69 100644 --- a/test_control.py +++ b/tests/test_mysql_control.py @@ -1,10 +1,10 @@ import os import unittest import time -from interface.control_new import DBControlInterface +from interface.control import DBControlInterface -class ControlTest(unittest.TestCase): +class MySQLControlTest(unittest.TestCase): def setUp(self): @@ -49,7 +49,7 @@ def test_T(self): tuple ) - """def test_Command(self): + def test_Command(self): command_dict = { "db": list, @@ -82,7 +82,6 @@ def test_T(self): ), value ) - """ def test_save(self): @@ -107,5 +106,3 @@ def test_demo(self): self.db.get_demo("radacct"), tuple ) - -unittest.main() \ No newline at end of file diff --git a/tests/test_sqlite3.py b/tests/test_sqlite3.py new file mode 100644 index 0000000..6697666 --- /dev/null +++ b/tests/test_sqlite3.py @@ -0,0 +1,25 @@ +import unittest +from database.models import DataBaseEngine + + +class TestModelSqlite3(unittest.TestCase): + + def setUp(self): + + self.db_engine = DataBaseEngine("sqlite3", "city_line") + self.db_engine.connect() + + def test_table(self): + + self.assertIsInstance( + self.db_engine.retrieve_table(), + list + ) + + def test_sql(self): + + self.db_engine.execute("select * from track_lines") + self.assertIsInstance( + self.db_engine.cursor.fetchall(), + list + ) diff --git a/tests/test_sqlite3_control.py b/tests/test_sqlite3_control.py new file mode 100644 index 0000000..34c4469 --- /dev/null +++ b/tests/test_sqlite3_control.py @@ -0,0 +1,101 @@ +import os +import unittest +import time +from interface.control import DBControlInterface + + +class MySQLControlTest(unittest.TestCase): + + def setUp(self): + + self.command_stored_in_buffer = "select * from track_lines;" + self.db = DBControlInterface("sqlite3", db_nickname="city_line") + self.db.connect() + + def test_table(self): + + self.assertIsInstance(self.db.get_table(), list) + + def test_Help(self): + + self.assertIsInstance(self.db.get_help(), str) + + def test_SQL(self): + + self.assertIsInstance( + self.db.get_sql("""select * from cities"""), + list + ) + + def test_R(self): + + self.db.command_stored_in_buffer = self.command_stored_in_buffer + self.assertEqual( + self.db.get_r(), + "Your Previous Command is: {}".format(self.command_stored_in_buffer) + ) + + def test_T(self): + + self.db.command_stored_in_buffer = self.command_stored_in_buffer + self.assertIsInstance( + self.db.get_t(), + list + ) + + def test_Command(self): + + command_dict = { + "db": list, + "help": str, + "r": str, + "t": str, + "table": list, + "sql": ( + "select * from track_lines", + list + ) + } + + for key, value in command_dict.items(): + + if key == "sql": + + self.assertIsInstance( + self.db.command_interface( + value[0], + ), + value[1] + ) + + else: + self.assertIsInstance( + self.db.command_interface( + key + ), + value + ) + + def test_save(self): + + self.db.get_save(["select * from track_lines", "test_track_lines.csv"]) + + time.sleep(0.5) + + self.assertTrue( + os.path.exists("csv_file/test_track_lines.csv") + ) + + def test_city(self): + + self.assertIsInstance( + self.db.get_city(), + list + ) + + def test_any_table(self): + + self.assertIsInstance( + self.db.get_any_table("track_lines"), + list + ) diff --git a/tests/test_webapp.py b/tests/test_webapp.py new file mode 100644 index 0000000..9206d8e --- /dev/null +++ b/tests/test_webapp.py @@ -0,0 +1,137 @@ +import os +import json +import unittest +from falcon import testing +from flask import jsonify + + +class TestWebAppFalcon(unittest.TestCase): + + def setUp(self): + + os.environ["DB_TYPE"] = "sqlite3" + os.environ["DB_NAME"] = "city_line" + + from app.webapp_falcon import app + + self.app = app + self.client = testing.TestClient(self.app) + + def test_webapp_debug(self): + + result = self.client.simulate_get("/debug") + self.assertEqual(result.status_code, 200) + self.assertEqual(result.text, "Webapp is running") + + def test_webapp_sql_api_get(self): + + result_get = self.client.simulate_get("/sql_api") + self.assertEqual(result_get.status_code, 200) + self.assertEqual( + result_get.json["status"], "500 Internal Server Error" + ) + + def test_webapp_sql_api_post(self): + + current_item_page = 50 + result_post = self.client.simulate_post( + "/sql_api", + json={ + "sql_command": "select * from cities", + "current_index": 0, + "item_per_page": current_item_page + } + ) + + self.assertEqual(result_post.status_code, 200) + + self.assertEqual(len(result_post.json["sql_response"]), current_item_page) + self.assertSetEqual( + set(result_post.json["sql_header"]), + set(result_post.json["sql_response"][0].keys()) + ) + + def test_webapp_sql_api_post_command(self): + + current_item_page = 7 + result_post = self.client.simulate_post( + "/sql_api", + json={ + "sql_command": "table", + "current_index": 0, + "item_per_page": current_item_page + } + ) + + self.assertEqual(result_post.status_code, 200) + + self.assertEqual(len(result_post.json["sql_response"]), current_item_page) + self.assertSetEqual( + set(result_post.json["sql_header"]), + set(result_post.json["sql_response"][0].keys()) + ) + + +class TestWebAppFlask(unittest.TestCase): + + def setUp(self): + + os.environ["DB_TYPE"] = "sqlite3" + os.environ["DB_NAME"] = "city_line" + + from app.webapp import app + + self.app = app.test_client() + self.app.testing = True + + def test_debug(self): + + result = self.app.get("/debug") + self.assertEqual(result.status_code, 200) + self.assertEqual(result.data.decode('utf-8'), "Webapp is running") + + def test_sql_api_get(self): + + result = self.app.get("/sql_api") + self.assertEqual(result.status_code, 200) + self.assertEqual( + result.json["status"], 403 + ) + + def test_sql_api_post(self): + + current_item_page = 50 + test_data = { + "sql_command": "select * from cities", + "current_index": 0, + "item_per_page": current_item_page + } + + result = self.app.post( + "/sql_api", + data=json.dumps(test_data), content_type="application/json" + ) + + self.assertEqual(result.status_code, 200) + self.assertEqual( + len(result.json["sql_response"]), current_item_page + ) + + def test_webapp_sql_api_post_command(self): + + current_item_page = 7 + + result_post = self.app.post( + "/sql_api", + data=json.dumps( + { + "sql_command": "table", + "current_index": 0, + "item_per_page": current_item_page + } + ), + content_type="application/json" + ) + + self.assertEqual(result_post.status_code, 200) + self.assertEqual(len(result_post.json["sql_response"]), current_item_page) \ No newline at end of file