1212import json
1313import os
1414import time
15- ENGINE = os .environ .get ("GPT_ENGINE" ) or "gpt-3.5-turbo"
1615ENCODER = tiktoken .get_encoding ("gpt2" )
1716class chatPaper :
1817 """
@@ -21,22 +20,20 @@ class chatPaper:
2120 def __init__ (
2221 self ,
2322 api_keys : list ,
24- engine = None ,
2523 proxy = None ,
2624 api_proxy = None ,
2725 max_tokens : int = 4000 ,
2826 temperature : float = 0.5 ,
2927 top_p : float = 1.0 ,
28+ model_name : str = "gpt-3.5-turbo" ,
3029 reply_count : int = 1 ,
3130 system_prompt = "You are ChatPaper, A paper reading bot" ,
3231 lastAPICallTime = time .time ()- 100 ,
3332 apiTimeInterval = 20 ,
34- maxBackup = 10 ,
3533 ) -> None :
36- self .maxBackup = maxBackup
34+ self .model_name = model_name
3735 self .system_prompt = system_prompt
3836 self .apiTimeInterval = apiTimeInterval
39- self .engine = engine or ENGINE
4037 self .session = requests .Session ()
4138 self .api_keys = PQ ()
4239 for key in api_keys :
@@ -89,19 +86,19 @@ def __truncate_conversation(self, convo_id: str = "default"):
8986 while (len (ENCODER .encode (str (query )))> self .max_tokens ):
9087 query = query [:self .decrease_step ]
9188 self .conversation [convo_id ] = self .conversation [convo_id ][:- 1 ]
92- full_conversation = "\n " .join ([x ["content" ] for x in self .conversation [convo_id ]],)
89+ full_conversation = "\n " .join ([str ( x ["content" ]) for x in self .conversation [convo_id ]],)
9390 if len (ENCODER .encode (full_conversation )) > self .max_tokens :
9491 self .conversation_summary (convo_id = convo_id )
95- last_dialog ['content' ] = query
96- self .conversation [convo_id ].append (last_dialog )
92+ full_conversation = ""
93+ for x in self .conversation [convo_id ]:
94+ full_conversation = str (x ["content" ]) + "\n " + full_conversation
9795 while True :
98- full_conversation = ""
99- for x in self .conversation [convo_id ]:
100- full_conversation = x ["content" ] + "\n "
101- if (len (ENCODER .encode (full_conversation )) > self .max_tokens ):
102- self .conversation [convo_id ][- 1 ] = self .conversation [convo_id ][- 1 ][:- self .decrease_step ]
96+ if (len (ENCODER .encode (full_conversation + query )) > self .max_tokens ):
97+ query = query [:self .decrease_step ]
10398 else :
10499 break
100+ last_dialog ['content' ] = str (query )
101+ self .conversation [convo_id ].append (last_dialog )
105102
106103 def ask_stream (
107104 self ,
@@ -119,7 +116,7 @@ def ask_stream(
119116 "https://api.openai.com/v1/chat/completions" ,
120117 headers = {"Authorization" : f"Bearer { kwargs .get ('api_key' , apiKey )} " },
121118 json = {
122- "model" : self .engine ,
119+ "model" : self .model_name ,
123120 "messages" : self .conversation [convo_id ],
124121 "stream" : True ,
125122 # kwargs
@@ -129,7 +126,7 @@ def ask_stream(
129126 "user" : role ,
130127 },
131128 stream = True ,
132- )
129+ )
133130 if response .status_code != 200 :
134131 raise Exception (
135132 f"Error: { response .status_code } { response .reason } { response .text } " ,
@@ -163,9 +160,31 @@ def ask(self, prompt: str, role: str = "user", convo_id: str = "default", **kwar
163160 )
164161 full_response : str = "" .join (response )
165162 self .add_to_conversation (full_response , role , convo_id = convo_id )
166- return full_response
167-
163+ usage_token = self .token_str (prompt )
164+ com_token = self .token_str (full_response )
165+ total_token = self .token_cost (convo_id = convo_id )
166+ return full_response , usage_token , com_token , total_token
168167
168+ def check_api_available (self ):
169+ response = self .session .post (
170+ "https://api.openai.com/v1/chat/completions" ,
171+ headers = {"Authorization" : f"Bearer { self .get_api_key ()} " },
172+ json = {
173+ "model" : self .engine ,
174+ "messages" : [{"role" : "system" , "content" : "You are a helpful assistant." },{"role" : "user" , "content" : "print A" }],
175+ "stream" : True ,
176+ # kwargs
177+ "temperature" : self .temperature ,
178+ "top_p" : self .top_p ,
179+ "n" : self .reply_count ,
180+ "user" : "user" ,
181+ },
182+ stream = True ,
183+ )
184+ if response .status_code == 200 :
185+ return True
186+ else :
187+ return False
169188 def reset (self , convo_id : str = "default" , system_prompt = None ):
170189 """
171190 Reset the conversation
0 commit comments