diff --git a/wgenroll.py b/wgenroll.py index e82c972..ee6a7b4 100644 --- a/wgenroll.py +++ b/wgenroll.py @@ -1,5 +1,12 @@ -from flask import Flask, request +from datetime import datetime +from flask import Flask, request, Response from tinydb import TinyDB, Query +import secrets +import wireguard_py +import ipaddress + +WIREGUARD_DEVICE_NAME = b"wg0" +WIREGUARD_NETWORK = "172.16.0.0/20" app = Flask(__name__) db = TinyDB("./wgenroll.json") @@ -8,26 +15,58 @@ db = TinyDB("./wgenroll.json") def create_pin(): r = request.get_json() + if "issuer_code" not in r.keys(): + return Response({"status": "issuer_code is required"}, status=400) + issuer = r["issuer_code"] - pin = r["pin"] + pin = generate_pin() if db.contains(Query().issuer == issuer): db.insert({"pin": pin, "active": True}) - return {"status": "pin created"} + return {"status": "pin created", "pin": pin} else: - return {"status": "invalid issuer"} - + return Response({"status": "invalid issuer"}, status=403) @app.route("/enrollPeer", methods=["Post"]) def enroll_peer(): r = request.get_json() + if "pin" not in r.keys() or "public_key" not in r.keys(): + return Response({"status": "pin and public_key are required"}, status=400) + pin = r["pin"] public_key = r["public_key"] + #TODO validate key format + Pin = Query() + if db.contains((Pin.pin == pin) & (Pin.active == True)): + db.insert({"wg_public_key": public_key, "created_at": datetime.now().isoformat()}) + enroll_wireguard_peer(public_key) + # deactivate the pin now that it's been used' + db.update({"active": False}, {"pin": pin}) + else: + return Response({"status": "invalid pin"}, status=403) -def setIssuerCode(code: str): +def enroll_wireguard_peer(public_key: str): + wireguard_py.set_peer( + device_name=WIREGUARD_DEVICE_NAME, + pub_key=bytes(public_key, "utf-8"), + endpoint=None, + allowed_ips={ + ipaddress.ip_network(WIREGUARD_NETWORK), + }, + replace_allowed_ips=True) + +def set_issuer_code(code: str): db.insert({"issuer": code}) + print("issuer code created") -def deleteIssuerCode(code: str): +def delete_issuer_code(code: str): db.remove({"issuer": code}) + print("issuer code deleted") + +def generate_pin(): + return "".join([secrets.choice("1234567890") for x in range(6)]) + +if __name__ == "__main__": + set_issuer_code("1111")