diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..bb6c970 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,6 @@ +{ + "tasks": { + "build": "echo \"No build steps required for this repository.\"", + "test": "pip install pytest && pytest" + } +} \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..28ab5a7 --- /dev/null +++ b/main.py @@ -0,0 +1,177 @@ +import threading +import requests +import json +import time +import sys +import http.server + +token = None + +def setup(): + resp = requests.post('https://github.com/login/device/code', headers={ + 'accept': 'application/json', + 'editor-version': 'Neovim/0.6.1', + 'editor-plugin-version': 'copilot.vim/1.16.0', + 'content-type': 'application/json', + 'user-agent': 'GithubCopilot/1.155.0', + 'accept-encoding': 'gzip,deflate,br' + }, data='{"client_id":"Iv1.b507a08c87ecfe98","scope":"read:user"}') + + + # Parse the response json, isolating the device_code, user_code, and verification_uri + resp_json = resp.json() + device_code = resp_json.get('device_code') + user_code = resp_json.get('user_code') + verification_uri = resp_json.get('verification_uri') + + # Print the user code and verification uri + print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.') + + + while True: + time.sleep(5) + resp = requests.post('https://github.com/login/oauth/access_token', headers={ + 'accept': 'application/json', + 'editor-version': 'Neovim/0.6.1', + 'editor-plugin-version': 'copilot.vim/1.16.0', + 'content-type': 'application/json', + 'user-agent': 'GithubCopilot/1.155.0', + 'accept-encoding': 'gzip,deflate,br' + }, data=f'{{"client_id":"Iv1.b507a08c87ecfe98","device_code":"{device_code}","grant_type":"urn:ietf:params:oauth:grant-type:device_code"}}') + + # Parse the response json, isolating the access_token + resp_json = resp.json() + access_token = resp_json.get('access_token') + + if access_token: + break + + # Save the access token to a file + with open('.copilot_token', 'w') as f: + f.write(access_token) + + print('Authentication success!') + + +def get_token(): + global token + # Check if the .copilot_token file exists + while True: + try: + with open('.copilot_token', 'r') as f: + access_token = f.read() + break + except FileNotFoundError: + setup() + # Get a session with the access token + resp = requests.get('https://api.github.com/copilot_internal/v2/token', headers={ + 'authorization': f'token {access_token}', + 'editor-version': 'Neovim/0.6.1', + 'editor-plugin-version': 'copilot.vim/1.16.0', + 'user-agent': 'GithubCopilot/1.155.0' + }) + + # Parse the response json, isolating the token + resp_json = resp.json() + token = resp_json.get('token') + + +def token_thread(): + global token + while True: + get_token() + time.sleep(25 * 60) + +def copilot(prompt, language='python'): + global token + # If the token is None, get a new one + if token is None or is_token_invalid(token): + get_token() + + try: + resp = requests.post('https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions', headers={'authorization': f'Bearer {token}'}, json={ + 'prompt': prompt, + 'suffix': '', + 'max_tokens': 1000, + 'temperature': 0, + 'top_p': 1, + 'n': 1, + 'stop': ['\n'], + 'nwo': 'github/copilot.vim', + 'stream': True, + 'extra': { + 'language': language + } + }) + except requests.exceptions.ConnectionError: + return '' + + result = '' + + # Parse the response text, splitting it by newlines + resp_text = resp.text.split('\n') + for line in resp_text: + # If the line contains a completion, print it + if line.startswith('data: {'): + # Parse the completion from the line as json + json_completion = json.loads(line[6:]) + completion = json_completion.get('choices')[0].get('text') + if completion: + result += completion + else: + result += '\n' + + return result + +# Check if the token is invalid through the exp field +def is_token_invalid(token): + if token is None or 'exp' not in token or extract_exp_value(token) <= time.time(): + return True + return False + +def extract_exp_value(token): + pairs = token.split(';') + for pair in pairs: + key, value = pair.split('=') + if key.strip() == 'exp': + return int(value.strip()) + return None + +class HTTPRequestHandler(http.server.BaseHTTPRequestHandler): + def do_POST(self): + # Get the request body + content_length = int(self.headers['Content-Length']) + body = self.rfile.read(content_length) + + # Parse the request body as json + body_json = json.loads(body) + + # Get the prompt from the request body + prompt = body_json.get('prompt') + language = body_json.get('language', 'python') + + # Get the completion from the copilot function + completion = copilot(prompt, language) + + # Send the completion as the response + self.send_response(200) + self.send_header('Content-type', 'text/plain') + self.end_headers() + self.wfile.write(completion.encode()) + + +def main(): + # Every 25 minutes, get a new token + threading.Thread(target=token_thread).start() + # Get the port to listen on from the command line + if len(sys.argv) < 2: + port = 8080 + else: + port = int(sys.argv[1]) + # Start the http server + httpd = http.server.HTTPServer(('0.0.0.0', port), HTTPRequestHandler) + print(f'Listening on port 0.0.0.0:{port}...') + httpd.serve_forever() + +if __name__ == '__main__': + main() diff --git a/test_main.py b/test_main.py new file mode 100644 index 0000000..2d4585a --- /dev/null +++ b/test_main.py @@ -0,0 +1,90 @@ +import unittest +from unittest.mock import patch, mock_open, MagicMock +import main +import time +import json +import http.server +import threading + +class TestMain(unittest.TestCase): + + @patch('main.requests.post') + @patch('main.open', new_callable=mock_open) + def test_setup(self, mock_file, mock_post): + mock_resp = MagicMock() + mock_resp.json.return_value = { + 'device_code': 'test_device_code', + 'user_code': 'test_user_code', + 'verification_uri': 'test_verification_uri' + } + mock_post.return_value = mock_resp + + with patch('builtins.print') as mock_print: + main.setup() + mock_print.assert_any_call('Please visit test_verification_uri and enter code test_user_code to authenticate.') + + mock_file().write.assert_called_once_with('test_access_token') + + @patch('main.requests.get') + @patch('main.open', new_callable=mock_open, read_data='test_access_token') + def test_get_token(self, mock_file, mock_get): + mock_resp = MagicMock() + mock_resp.json.return_value = {'token': 'test_token'} + mock_get.return_value = mock_resp + + main.get_token() + self.assertEqual(main.token, 'test_token') + + @patch('main.get_token') + @patch('main.time.sleep', return_value=None) + def test_token_thread(self, mock_sleep, mock_get_token): + with patch('threading.Thread.start') as mock_start: + thread = threading.Thread(target=main.token_thread) + thread.start() + mock_start.assert_called_once() + + @patch('main.requests.post') + @patch('main.is_token_invalid', return_value=True) + @patch('main.get_token') + def test_copilot(self, mock_get_token, mock_is_token_invalid, mock_post): + mock_resp = MagicMock() + mock_resp.text = 'data: {"choices":[{"text":"test_completion"}]}' + mock_post.return_value = mock_resp + + result = main.copilot('test_prompt') + self.assertEqual(result, 'test_completion') + + def test_is_token_invalid(self): + valid_token = 'exp=9999999999' + invalid_token = 'exp=0' + self.assertFalse(main.is_token_invalid(valid_token)) + self.assertTrue(main.is_token_invalid(invalid_token)) + + def test_extract_exp_value(self): + token = 'key1=value1; exp=1234567890; key2=value2' + self.assertEqual(main.extract_exp_value(token), 1234567890) + + @patch('main.copilot', return_value='test_completion') + def test_HTTPRequestHandler(self, mock_copilot): + handler = main.HTTPRequestHandler + handler.rfile = MagicMock() + handler.rfile.read.return_value = json.dumps({'prompt': 'test_prompt', 'language': 'python'}).encode() + handler.headers = {'Content-Length': len(handler.rfile.read.return_value)} + handler.wfile = MagicMock() + + handler.do_POST(handler) + + handler.wfile.write.assert_called_once_with(b'test_completion') + + @patch('main.threading.Thread.start') + @patch('main.http.server.HTTPServer.serve_forever') + def test_main(self, mock_serve_forever, mock_thread_start): + with patch('builtins.print') as mock_print: + with patch('sys.argv', ['main.py', '8080']): + main.main() + mock_print.assert_any_call('Listening on port 0.0.0.0:8080...') + mock_thread_start.assert_called_once() + mock_serve_forever.assert_called_once() + +if __name__ == '__main__': + unittest.main()