// BaiduChatUtils.java (8.8 KB)
  1. package com.goafanti.common.utils;
  2. import com.alibaba.fastjson.JSON;
  3. import com.alibaba.fastjson.JSONObject;
  4. import com.goafanti.baiduAI.BaiduChatErrorEnums;
  5. import com.goafanti.baiduAI.bo.*;
  6. import com.goafanti.common.error.BusinessException;
  7. import okhttp3.*;
  8. import org.springframework.beans.factory.annotation.Autowired;
  9. import org.springframework.beans.factory.annotation.Value;
  10. import org.springframework.scheduling.annotation.Async;
  11. import org.springframework.stereotype.Component;
  12. import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
  13. import java.io.BufferedReader;
  14. import java.io.IOException;
  15. import java.io.InputStream;
  16. import java.io.InputStreamReader;
  17. import java.net.HttpURLConnection;
  18. import java.net.MalformedURLException;
  19. import java.net.ProtocolException;
  20. import java.net.URL;
  21. import java.nio.charset.StandardCharsets;
  22. import java.util.Calendar;
  23. import java.util.HashMap;
  24. import java.util.Map;
  25. import java.util.concurrent.TimeUnit;
  26. public class BaiduChatUtils {
  27. @Value(value = "${baidu.ApiKey}")
  28. private String baiduApiKey=null;
  29. @Value(value = "${baidu.SecretKey}")
  30. private String baiduSecretKey=null;
  31. @Autowired
  32. private RedisUtil redisUtil;
  33. /*文心一言地址*/
  34. private static final String BAIDU_CHAT_WXYY_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=";
  35. /*Ernie-Lite地址*/
  36. private static final String BAIDU_CHAT_ERNIE_LITE_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=";
  37. /*accessToken获取地址*/
  38. private static final String BAIDU_ACCESSTOKEN_URL="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&";
  39. static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().connectTimeout(120000, TimeUnit.MILLISECONDS)
  40. .readTimeout(120000, TimeUnit.MILLISECONDS)
  41. .build();
  42. public String getBaiduAccessToken() throws IOException {
  43. MediaType mediaType = MediaType.parse("application/json");
  44. RequestBody body = RequestBody.create(mediaType, "");
  45. StringBuffer url= new StringBuffer(BAIDU_ACCESSTOKEN_URL)
  46. .append("client_id=").append(baiduApiKey).append("&client_secret=").append(baiduSecretKey);
  47. Request request = new Request.Builder()
  48. .url(url.toString())
  49. .method("POST", body)
  50. .addHeader("Content-Type", "application/json")
  51. .addHeader("Accept", "application/json")
  52. .build();
  53. Response response = HTTP_CLIENT.newCall(request).execute();
  54. String result=response.body().string();
  55. HashMap<String,Object> map = JSON.parseObject(result, HashMap.class);
  56. String accessToken = map.get("access_token").toString();
  57. LoggerUtils.debug(getClass(),"获取accessToken="+accessToken);
  58. return accessToken;
  59. }
  60. public String sendBaiduAI(InputSendChat in) throws IOException{
  61. String accessToken = getRedisBaiduAccessToken();
  62. MediaType mediaType = MediaType.parse("application/json");
  63. RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(in));
  64. Request request = new Request.Builder()
  65. .url(BAIDU_CHAT_WXYY_URL + accessToken)
  66. .method("POST", body)
  67. .addHeader("Content-Type", "application/json")
  68. .build();
  69. Response response = HTTP_CLIENT.newCall(request).execute();
  70. String result=response.body().string();
  71. return result;
  72. }
  73. @Async
  74. public void sendBaiduAiStream(InputSendChat inputSendChat) throws IOException {
  75. try {
  76. String baidu_url=BAIDU_CHAT_WXYY_URL+getRedisBaiduAccessToken();
  77. URL url = new URL(baidu_url);
  78. HttpURLConnection connection = (HttpURLConnection) url.openConnection();
  79. connection.setRequestMethod("POST");
  80. connection.setRequestProperty("Content-Type", "application/json");
  81. connection.setDoInput(true);
  82. connection.setDoOutput(true);
  83. // 构造请求体
  84. // String requestBody = "{\"messages\":[{\"role\":\"user\",\"content\":\"给我介绍一条从四川自驾到拉萨的路线\"}],\"stream\":true}";
  85. String requestBody=JSON.toJSONString(inputSendChat);
  86. byte[] postData = requestBody.getBytes(StandardCharsets.UTF_8);
  87. connection.setRequestProperty("Content-Length", String.valueOf(postData.length));
  88. connection.getOutputStream().write(postData);
  89. InputStream responseStream = connection.getInputStream();
  90. BufferedReader reader = new BufferedReader(new InputStreamReader(responseStream,"UTF-8"));
  91. String line;
  92. SseResult res = SseMap.sseEmitterMap.get(inputSendChat.getUserId());
  93. while ((line = reader.readLine())!= null) {
  94. // 每行数据中以 "data:" 开头的部分即为实际的响应数据
  95. System.out.println("relut="+line+"。");
  96. if (line.startsWith("data:")) {
  97. String data = line.substring("data:".length()).trim();
  98. JSONObject jsonObject = JSONObject.parseObject(data);
  99. Boolean isEnd = jsonObject.getBoolean("is_end");
  100. if(isEnd){
  101. break;
  102. }
  103. // OutSendChatOK out =jsonObject.toJavaObject(OutSendChatOK.class);
  104. // System.out.println(data);
  105. res.sseEmitter.send(data);
  106. }else if(line.startsWith("{")) {
  107. JSONObject jsonObject = JSONObject.parseObject(line);
  108. jsonObject.getString("error_code");
  109. res.sseEmitter.send(line);
  110. }
  111. }
  112. res.sseEmitter.complete();
  113. reader.close();
  114. } catch (MalformedURLException e) {
  115. e.printStackTrace();
  116. sendJitaoBaiWen(inputSendChat.getUserId(),"data:{\"error_code\":\"2\",\"error_msg\":\"域名解析异常\"}");
  117. } catch (ProtocolException e) {
  118. e.printStackTrace();
  119. }
  120. SseMap.sseEmitterMap.remove(inputSendChat.getUserId());
  121. }
  122. private OutSendChat pushResultToOutSendChat(String result) {
  123. Map<String ,Object> resultMap=JSON.parseObject(result, Map.class);
  124. Integer errorCode= (Integer) resultMap.get("error_code");
  125. if (errorCode!=null){
  126. OutChatER res=new OutChatER();
  127. res.setErrorCode(errorCode);
  128. if (errorCode.equals("336003")){
  129. res.setErrorMsg(BaiduChatErrorEnums.BycodeGetMsg(errorCode)+resultMap.get("error_msg"));
  130. }else {
  131. res.setErrorMsg(BaiduChatErrorEnums.BycodeGetMsg(errorCode));
  132. }
  133. return res;
  134. }else {
  135. OutSendChatOK res=new OutSendChatOK();
  136. res=JSON.parseObject(result,OutSendChatOK.class);
  137. return res;
  138. }
  139. }
  140. private void sendJitaoBaiWen(String userId, String s) throws IOException {
  141. SseResult res = SseMap.sseEmitterMap.get(userId);
  142. res.sseEmitter.send(s);
  143. }
  144. private String getRedisBaiduAccessToken() {
  145. String redisAccessToken=null;
  146. String redisTime=redisUtil.getString("baiduAccessTime");
  147. //没有
  148. if (redisTime !=null){
  149. Calendar cal = Calendar.getInstance();
  150. Long redisAccessTime=Long.valueOf(redisTime);
  151. if (cal.getTimeInMillis()>redisAccessTime){
  152. redisUtil.deleteString("baiduAccessToken");
  153. redisUtil.deleteString("baiduAccessTime");
  154. redisAccessToken=pushRedisBaiduAccessToken();
  155. LoggerUtils.debug(getClass(),"accessToken过期,重新获取!");
  156. }else {
  157. redisAccessToken= redisUtil.getString("baiduAccessToken");
  158. LoggerUtils.debug(getClass(),"accessToken从redis获取成功!");
  159. }
  160. }else {
  161. redisAccessToken=pushRedisBaiduAccessToken();
  162. LoggerUtils.debug(getClass(),"accessToken不存在,从百度获取!");
  163. }
  164. return redisAccessToken;
  165. }
  166. private String pushRedisBaiduAccessToken() {
  167. String baiduAccessToken;
  168. try {
  169. baiduAccessToken = getBaiduAccessToken();
  170. } catch (IOException e) {
  171. throw new BusinessException("baiduAccessToken获取失败");
  172. }
  173. //获取当前系统时间
  174. Calendar cal = Calendar.getInstance();
  175. //将时间增加三十天
  176. cal.add(Calendar.DATE, 30);
  177. //获取改变后的时间
  178. Long baiduAccessTime= cal.getTimeInMillis();
  179. redisUtil.setString("baiduAccessToken",baiduAccessToken);
  180. redisUtil.setString("baiduAccessTime",baiduAccessTime.toString());
  181. return baiduAccessToken;
  182. }
  183. }