BaiduChatUtils.java 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package com.goafanti.common.utils;
  2. import com.alibaba.fastjson.JSON;
  3. import com.alibaba.fastjson.JSONObject;
  4. import com.goafanti.baiduAI.BaiduChatErrorEnums;
  5. import com.goafanti.baiduAI.bo.*;
  6. import com.goafanti.common.error.BusinessException;
  7. import okhttp3.*;
  8. import org.springframework.beans.factory.annotation.Autowired;
  9. import org.springframework.beans.factory.annotation.Value;
  10. import org.springframework.scheduling.annotation.Async;
  11. import java.io.*;
  12. import java.net.HttpURLConnection;
  13. import java.net.MalformedURLException;
  14. import java.net.ProtocolException;
  15. import java.net.URL;
  16. import java.nio.charset.StandardCharsets;
  17. import java.util.Calendar;
  18. import java.util.HashMap;
  19. import java.util.concurrent.TimeUnit;
  20. public class BaiduChatUtils {
  21. @Value(value = "${baidu.ApiKey}")
  22. private String baiduApiKey=null;
  23. @Value(value = "${baidu.SecretKey}")
  24. private String baiduSecretKey=null;
  25. @Autowired
  26. private RedisUtil redisUtil;
  27. /*文心一言地址*/
  28. private static final String BAIDU_CHAT_WXYY_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=";
  29. /*Ernie-Lite地址*/
  30. private static final String BAIDU_CHAT_ERNIE_LITE_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=";
  31. /*accessToken获取地址*/
  32. private static final String BAIDU_ACCESSTOKEN_URL="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&";
  33. static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().connectTimeout(120000, TimeUnit.MILLISECONDS)
  34. .readTimeout(120000, TimeUnit.MILLISECONDS)
  35. .build();
  36. @Async
  37. public void sendBaiduAiStream(InputSendChat inputSendChat) throws IOException {
  38. SseResult res = null;
  39. BufferedReader reader=null;
  40. InputStreamReader inputStreamReader=null;
  41. OutputStream outputStream=null;
  42. try {
  43. String baidu_url=BAIDU_CHAT_WXYY_URL+getRedisBaiduAccessToken();
  44. URL url = new URL(baidu_url);
  45. HttpURLConnection connection = (HttpURLConnection) url.openConnection();
  46. connection.setRequestMethod("POST");
  47. connection.setRequestProperty("Content-Type", "application/json");
  48. connection.setDoInput(true);
  49. connection.setDoOutput(true);
  50. // 构造请求体
  51. // String requestBody = "{\"messages\":[{\"role\":\"user\",\"content\":\"给我介绍一条从四川自驾到拉萨的路线\"}],\"stream\":true}";
  52. String requestBody=JSON.toJSONString(inputSendChat);
  53. byte[] postData = requestBody.getBytes(StandardCharsets.UTF_8);
  54. connection.setRequestProperty("Content-Length", String.valueOf(postData.length));
  55. outputStream =connection.getOutputStream();
  56. outputStream.write(postData);
  57. InputStream responseStream = connection.getInputStream();
  58. inputStreamReader = new InputStreamReader(responseStream, "UTF-8");
  59. reader = new BufferedReader(inputStreamReader);
  60. String line;
  61. res = SseMap.sseEmitterMap.get(inputSendChat.getUserId());
  62. while ((line = reader.readLine())!= null) {
  63. // 每行数据中以 "data:" 开头的部分即为实际的响应数据
  64. if (StringUtils.isNotBlank(line)){
  65. if (line.startsWith("data:")) {
  66. String data = line.substring("data:".length()).trim();
  67. JSONObject jsonObject = JSONObject.parseObject(data);
  68. Boolean isEnd = jsonObject.getBoolean("is_end");
  69. res.sseEmitter.send(data);
  70. if(isEnd){
  71. break;
  72. }
  73. //错误返回格式{"error_code":110,"error_msg":"Access token invalid or no longer valid"}
  74. }else if(line.startsWith("{")) {
  75. JSONObject jsonObject = JSONObject.parseObject(line);
  76. Integer errorCode = jsonObject.getInteger("error_code");
  77. OutChatER out=new OutChatER();
  78. if (errorCode!=null){
  79. out.setError_code(errorCode);
  80. if (errorCode.equals("336003")){
  81. out.setError_msg(BaiduChatErrorEnums.BycodeGetMsg(errorCode)+jsonObject.getString("error_msg"));
  82. }else {
  83. out.setError_msg(BaiduChatErrorEnums.BycodeGetMsg(errorCode));
  84. }
  85. }
  86. String errorStr=JSON.toJSONString(out);
  87. LoggerUtils.debug(getClass(),errorStr);
  88. res.sseEmitter.send(errorStr);
  89. }
  90. }
  91. }
  92. res.sseEmitter.complete();
  93. } catch (MalformedURLException e) {
  94. e.printStackTrace();
  95. sendJitaoBaiWen(inputSendChat.getUserId(),"data:{\"error_code\":\"2\",\"error_msg\":\"域名解析异常\"}");
  96. } catch (ProtocolException e) {
  97. e.printStackTrace();
  98. }catch (IllegalStateException e){
  99. LoggerUtils.debug(getClass(),"前端网页已关闭");
  100. }finally {
  101. inputStreamReader.close();
  102. outputStream.close();
  103. reader.close();
  104. res.sseEmitter.complete();
  105. SseMap.sseEmitterMap.remove(inputSendChat.getUserId());
  106. }
  107. }
  108. public String getBaiduAccessToken() throws IOException {
  109. MediaType mediaType = MediaType.parse("application/json");
  110. RequestBody body = RequestBody.create(mediaType, "");
  111. StringBuffer url= new StringBuffer(BAIDU_ACCESSTOKEN_URL)
  112. .append("client_id=").append(baiduApiKey).append("&client_secret=").append(baiduSecretKey);
  113. Request request = new Request.Builder()
  114. .url(url.toString())
  115. .method("POST", body)
  116. .addHeader("Content-Type", "application/json")
  117. .addHeader("Accept", "application/json")
  118. .build();
  119. Response response = HTTP_CLIENT.newCall(request).execute();
  120. String result=response.body().string();
  121. HashMap<String,Object> map = JSON.parseObject(result, HashMap.class);
  122. String accessToken = map.get("access_token").toString();
  123. LoggerUtils.debug(getClass(),"获取accessToken="+accessToken);
  124. return accessToken;
  125. }
  126. private void sendJitaoBaiWen(String userId, String s) throws IOException {
  127. SseResult res = SseMap.sseEmitterMap.get(userId);
  128. res.sseEmitter.send(s);
  129. }
  130. private String getRedisBaiduAccessToken() {
  131. String redisAccessToken;
  132. String redisTime=redisUtil.getString("baiduAccessTime");
  133. if (redisTime !=null){
  134. Calendar cal = Calendar.getInstance();
  135. Long redisAccessTime=Long.valueOf(redisTime);
  136. if (cal.getTimeInMillis()>redisAccessTime){
  137. redisUtil.deleteString("baiduAccessToken");
  138. redisUtil.deleteString("baiduAccessTime");
  139. redisAccessToken=pushRedisBaiduAccessToken();
  140. LoggerUtils.debug(getClass(),"accessToken过期,重新获取!");
  141. }else {
  142. redisAccessToken= redisUtil.getString("baiduAccessToken");
  143. }
  144. }else {
  145. redisAccessToken=pushRedisBaiduAccessToken();
  146. LoggerUtils.debug(getClass(),"accessToken不存在,从百度获取!");
  147. }
  148. return redisAccessToken;
  149. }
  150. private String pushRedisBaiduAccessToken() {
  151. String baiduAccessToken;
  152. try {
  153. baiduAccessToken = getBaiduAccessToken();
  154. } catch (IOException e) {
  155. throw new BusinessException("baiduAccessToken获取失败");
  156. }
  157. //获取当前系统时间
  158. Calendar cal = Calendar.getInstance();
  159. //将时间增加三十天
  160. cal.add(Calendar.DATE, 30);
  161. //获取改变后的时间
  162. Long baiduAccessTime= cal.getTimeInMillis();
  163. redisUtil.deleteString("baiduAccessToken");
  164. redisUtil.deleteString("baiduAccessTime");
  165. redisUtil.setString("baiduAccessToken",baiduAccessToken);
  166. redisUtil.setString("baiduAccessTime",baiduAccessTime.toString());
  167. return baiduAccessToken;
  168. }
  169. }