mirror of
https://github.com/elder-plinius/L1B3RT4S.git
synced 2025-09-26 02:33:39 +02:00
Merge 4b4925bff7
into 3624426245
This commit is contained in:
commit
8666148cb2
3 changed files with 273 additions and 0 deletions
6
.devcontainer/devcontainer.json
Normal file
6
.devcontainer/devcontainer.json
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
{
|
||||||
|
"tasks": {
|
||||||
|
"build": "echo \"No build steps required for this repository.\"",
|
||||||
|
"test": "pip install pytest && pytest"
|
||||||
|
}
|
||||||
|
}
|
177
main.py
Normal file
177
main.py
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
import threading
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import http.server
|
||||||
|
|
||||||
|
token = None
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
resp = requests.post('https://github.com/login/device/code', headers={
|
||||||
|
'accept': 'application/json',
|
||||||
|
'editor-version': 'Neovim/0.6.1',
|
||||||
|
'editor-plugin-version': 'copilot.vim/1.16.0',
|
||||||
|
'content-type': 'application/json',
|
||||||
|
'user-agent': 'GithubCopilot/1.155.0',
|
||||||
|
'accept-encoding': 'gzip,deflate,br'
|
||||||
|
}, data='{"client_id":"Iv1.b507a08c87ecfe98","scope":"read:user"}')
|
||||||
|
|
||||||
|
|
||||||
|
# Parse the response json, isolating the device_code, user_code, and verification_uri
|
||||||
|
resp_json = resp.json()
|
||||||
|
device_code = resp_json.get('device_code')
|
||||||
|
user_code = resp_json.get('user_code')
|
||||||
|
verification_uri = resp_json.get('verification_uri')
|
||||||
|
|
||||||
|
# Print the user code and verification uri
|
||||||
|
print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.')
|
||||||
|
|
||||||
|
|
||||||
|
while True:
|
||||||
|
time.sleep(5)
|
||||||
|
resp = requests.post('https://github.com/login/oauth/access_token', headers={
|
||||||
|
'accept': 'application/json',
|
||||||
|
'editor-version': 'Neovim/0.6.1',
|
||||||
|
'editor-plugin-version': 'copilot.vim/1.16.0',
|
||||||
|
'content-type': 'application/json',
|
||||||
|
'user-agent': 'GithubCopilot/1.155.0',
|
||||||
|
'accept-encoding': 'gzip,deflate,br'
|
||||||
|
}, data=f'{{"client_id":"Iv1.b507a08c87ecfe98","device_code":"{device_code}","grant_type":"urn:ietf:params:oauth:grant-type:device_code"}}')
|
||||||
|
|
||||||
|
# Parse the response json, isolating the access_token
|
||||||
|
resp_json = resp.json()
|
||||||
|
access_token = resp_json.get('access_token')
|
||||||
|
|
||||||
|
if access_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Save the access token to a file
|
||||||
|
with open('.copilot_token', 'w') as f:
|
||||||
|
f.write(access_token)
|
||||||
|
|
||||||
|
print('Authentication success!')
|
||||||
|
|
||||||
|
|
||||||
|
def get_token():
|
||||||
|
global token
|
||||||
|
# Check if the .copilot_token file exists
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with open('.copilot_token', 'r') as f:
|
||||||
|
access_token = f.read()
|
||||||
|
break
|
||||||
|
except FileNotFoundError:
|
||||||
|
setup()
|
||||||
|
# Get a session with the access token
|
||||||
|
resp = requests.get('https://api.github.com/copilot_internal/v2/token', headers={
|
||||||
|
'authorization': f'token {access_token}',
|
||||||
|
'editor-version': 'Neovim/0.6.1',
|
||||||
|
'editor-plugin-version': 'copilot.vim/1.16.0',
|
||||||
|
'user-agent': 'GithubCopilot/1.155.0'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Parse the response json, isolating the token
|
||||||
|
resp_json = resp.json()
|
||||||
|
token = resp_json.get('token')
|
||||||
|
|
||||||
|
|
||||||
|
def token_thread():
|
||||||
|
global token
|
||||||
|
while True:
|
||||||
|
get_token()
|
||||||
|
time.sleep(25 * 60)
|
||||||
|
|
||||||
|
def copilot(prompt, language='python'):
|
||||||
|
global token
|
||||||
|
# If the token is None, get a new one
|
||||||
|
if token is None or is_token_invalid(token):
|
||||||
|
get_token()
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = requests.post('https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions', headers={'authorization': f'Bearer {token}'}, json={
|
||||||
|
'prompt': prompt,
|
||||||
|
'suffix': '',
|
||||||
|
'max_tokens': 1000,
|
||||||
|
'temperature': 0,
|
||||||
|
'top_p': 1,
|
||||||
|
'n': 1,
|
||||||
|
'stop': ['\n'],
|
||||||
|
'nwo': 'github/copilot.vim',
|
||||||
|
'stream': True,
|
||||||
|
'extra': {
|
||||||
|
'language': language
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
result = ''
|
||||||
|
|
||||||
|
# Parse the response text, splitting it by newlines
|
||||||
|
resp_text = resp.text.split('\n')
|
||||||
|
for line in resp_text:
|
||||||
|
# If the line contains a completion, print it
|
||||||
|
if line.startswith('data: {'):
|
||||||
|
# Parse the completion from the line as json
|
||||||
|
json_completion = json.loads(line[6:])
|
||||||
|
completion = json_completion.get('choices')[0].get('text')
|
||||||
|
if completion:
|
||||||
|
result += completion
|
||||||
|
else:
|
||||||
|
result += '\n'
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Check if the token is invalid through the exp field
|
||||||
|
def is_token_invalid(token):
|
||||||
|
if token is None or 'exp' not in token or extract_exp_value(token) <= time.time():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def extract_exp_value(token):
|
||||||
|
pairs = token.split(';')
|
||||||
|
for pair in pairs:
|
||||||
|
key, value = pair.split('=')
|
||||||
|
if key.strip() == 'exp':
|
||||||
|
return int(value.strip())
|
||||||
|
return None
|
||||||
|
|
||||||
|
class HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||||
|
def do_POST(self):
|
||||||
|
# Get the request body
|
||||||
|
content_length = int(self.headers['Content-Length'])
|
||||||
|
body = self.rfile.read(content_length)
|
||||||
|
|
||||||
|
# Parse the request body as json
|
||||||
|
body_json = json.loads(body)
|
||||||
|
|
||||||
|
# Get the prompt from the request body
|
||||||
|
prompt = body_json.get('prompt')
|
||||||
|
language = body_json.get('language', 'python')
|
||||||
|
|
||||||
|
# Get the completion from the copilot function
|
||||||
|
completion = copilot(prompt, language)
|
||||||
|
|
||||||
|
# Send the completion as the response
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-type', 'text/plain')
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(completion.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Every 25 minutes, get a new token
|
||||||
|
threading.Thread(target=token_thread).start()
|
||||||
|
# Get the port to listen on from the command line
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
port = 8080
|
||||||
|
else:
|
||||||
|
port = int(sys.argv[1])
|
||||||
|
# Start the http server
|
||||||
|
httpd = http.server.HTTPServer(('0.0.0.0', port), HTTPRequestHandler)
|
||||||
|
print(f'Listening on port 0.0.0.0:{port}...')
|
||||||
|
httpd.serve_forever()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
90
test_main.py
Normal file
90
test_main.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, mock_open, MagicMock
|
||||||
|
import main
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import http.server
|
||||||
|
import threading
|
||||||
|
|
||||||
|
class TestMain(unittest.TestCase):
|
||||||
|
|
||||||
|
@patch('main.requests.post')
|
||||||
|
@patch('main.open', new_callable=mock_open)
|
||||||
|
def test_setup(self, mock_file, mock_post):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
'device_code': 'test_device_code',
|
||||||
|
'user_code': 'test_user_code',
|
||||||
|
'verification_uri': 'test_verification_uri'
|
||||||
|
}
|
||||||
|
mock_post.return_value = mock_resp
|
||||||
|
|
||||||
|
with patch('builtins.print') as mock_print:
|
||||||
|
main.setup()
|
||||||
|
mock_print.assert_any_call('Please visit test_verification_uri and enter code test_user_code to authenticate.')
|
||||||
|
|
||||||
|
mock_file().write.assert_called_once_with('test_access_token')
|
||||||
|
|
||||||
|
@patch('main.requests.get')
|
||||||
|
@patch('main.open', new_callable=mock_open, read_data='test_access_token')
|
||||||
|
def test_get_token(self, mock_file, mock_get):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {'token': 'test_token'}
|
||||||
|
mock_get.return_value = mock_resp
|
||||||
|
|
||||||
|
main.get_token()
|
||||||
|
self.assertEqual(main.token, 'test_token')
|
||||||
|
|
||||||
|
@patch('main.get_token')
|
||||||
|
@patch('main.time.sleep', return_value=None)
|
||||||
|
def test_token_thread(self, mock_sleep, mock_get_token):
|
||||||
|
with patch('threading.Thread.start') as mock_start:
|
||||||
|
thread = threading.Thread(target=main.token_thread)
|
||||||
|
thread.start()
|
||||||
|
mock_start.assert_called_once()
|
||||||
|
|
||||||
|
@patch('main.requests.post')
|
||||||
|
@patch('main.is_token_invalid', return_value=True)
|
||||||
|
@patch('main.get_token')
|
||||||
|
def test_copilot(self, mock_get_token, mock_is_token_invalid, mock_post):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.text = 'data: {"choices":[{"text":"test_completion"}]}'
|
||||||
|
mock_post.return_value = mock_resp
|
||||||
|
|
||||||
|
result = main.copilot('test_prompt')
|
||||||
|
self.assertEqual(result, 'test_completion')
|
||||||
|
|
||||||
|
def test_is_token_invalid(self):
|
||||||
|
valid_token = 'exp=9999999999'
|
||||||
|
invalid_token = 'exp=0'
|
||||||
|
self.assertFalse(main.is_token_invalid(valid_token))
|
||||||
|
self.assertTrue(main.is_token_invalid(invalid_token))
|
||||||
|
|
||||||
|
def test_extract_exp_value(self):
|
||||||
|
token = 'key1=value1; exp=1234567890; key2=value2'
|
||||||
|
self.assertEqual(main.extract_exp_value(token), 1234567890)
|
||||||
|
|
||||||
|
@patch('main.copilot', return_value='test_completion')
|
||||||
|
def test_HTTPRequestHandler(self, mock_copilot):
|
||||||
|
handler = main.HTTPRequestHandler
|
||||||
|
handler.rfile = MagicMock()
|
||||||
|
handler.rfile.read.return_value = json.dumps({'prompt': 'test_prompt', 'language': 'python'}).encode()
|
||||||
|
handler.headers = {'Content-Length': len(handler.rfile.read.return_value)}
|
||||||
|
handler.wfile = MagicMock()
|
||||||
|
|
||||||
|
handler.do_POST(handler)
|
||||||
|
|
||||||
|
handler.wfile.write.assert_called_once_with(b'test_completion')
|
||||||
|
|
||||||
|
@patch('main.threading.Thread.start')
|
||||||
|
@patch('main.http.server.HTTPServer.serve_forever')
|
||||||
|
def test_main(self, mock_serve_forever, mock_thread_start):
|
||||||
|
with patch('builtins.print') as mock_print:
|
||||||
|
with patch('sys.argv', ['main.py', '8080']):
|
||||||
|
main.main()
|
||||||
|
mock_print.assert_any_call('Listening on port 0.0.0.0:8080...')
|
||||||
|
mock_thread_start.assert_called_once()
|
||||||
|
mock_serve_forever.assert_called_once()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue