mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-14 12:23:39 +00:00
v1.0.3
This commit is contained in:
@@ -19,6 +19,7 @@ import java.util.List;
|
||||
@RequestMapping("/mj/account")
|
||||
@RequiredArgsConstructor
|
||||
public class AccountController {
|
||||
|
||||
private final DiscordLoadBalancer loadBalancer;
|
||||
|
||||
@ApiOperation(value = "指定ID获取账号")
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.xmzs.common.chat.constant.OpenAIConst;
|
||||
import com.xmzs.common.core.domain.model.LoginUser;
|
||||
import com.xmzs.common.core.exception.base.BaseException;
|
||||
import com.xmzs.common.satoken.utils.LoginHelper;
|
||||
import com.xmzs.midjourney.domain.InsightFace;
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
import com.xmzs.system.service.IChatMessageService;
|
||||
import com.xmzs.system.service.IChatService;
|
||||
import com.xmzs.system.service.ISseService;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.MediaType;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import okio.Buffer;
|
||||
import okio.BufferedSink;
|
||||
import okio.GzipSink;
|
||||
import okio.Okio;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Api(tags = "任务查询")
|
||||
@RestController
|
||||
@RequestMapping("/mj")
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class FaceController {
|
||||
|
||||
@Value("${chat.apiKey}")
|
||||
private String apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
|
||||
@Autowired
|
||||
private IChatService chatService;
|
||||
|
||||
@Autowired
|
||||
private ISseService sseService;
|
||||
|
||||
@ApiOperation(value = "换脸")
|
||||
@PostMapping("/insight-face/swap")
|
||||
public String insightFace(@RequestBody InsightFace insightFace) {
|
||||
// 查询是否是付费用户
|
||||
sseService.checkUserGrade();
|
||||
// 扣除接口费用
|
||||
chatService.mjTaskDeduct("换脸", OpenAIConst.MJ_COST_TYPE2);
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
|
||||
.writeTimeout(30, TimeUnit.SECONDS) // 写入超时时间
|
||||
.readTimeout(30, TimeUnit.SECONDS) // 读取超时时间
|
||||
.build();
|
||||
// 创建一个Request对象来配置你的请求
|
||||
// 创建请求体(这里使用JSON作为媒体类型)
|
||||
String jsonStr = JSONUtil.toJsonStr(insightFace);
|
||||
|
||||
MediaType JSON = MediaType.get("application/json; charset=utf-8");
|
||||
okhttp3.RequestBody body = okhttp3.RequestBody.create(jsonStr, JSON);
|
||||
Buffer buffer = new Buffer();
|
||||
GzipSink gzipSink = new GzipSink(buffer);
|
||||
BufferedSink gzipBufferedSink = Okio.buffer(gzipSink);
|
||||
try {
|
||||
body.writeTo(gzipBufferedSink);
|
||||
gzipBufferedSink.close();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 创建POST请求
|
||||
Request request = new Request.Builder()
|
||||
.header("mj-api-secret", apiKey)
|
||||
.header("Content-Encoding", "gzip")
|
||||
.url(apiHost + "mj/insight-face/swap") // 替换为你的URL
|
||||
.post(body)
|
||||
.build();
|
||||
|
||||
try (Response response = client.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
|
||||
if (response.body() != null) {
|
||||
return response.body().string();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("换脸失败! {}", e.getMessage());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,20 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.core.comparator.CompareUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.xmzs.midjourney.dto.SubmitImagineDTO;
|
||||
import com.xmzs.midjourney.dto.TaskConditionDTO;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import io.swagger.annotations.ApiParam;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
@@ -16,6 +22,7 @@ import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
@@ -25,15 +32,35 @@ import java.util.Objects;
|
||||
@RestController
|
||||
@RequestMapping("/mj/task")
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class TaskController {
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final DiscordLoadBalancer discordLoadBalancer;
|
||||
|
||||
@Value("${chat.apiKey}")
|
||||
private String apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
|
||||
@ApiOperation(value = "指定ID获取任务")
|
||||
@GetMapping("/{id}/fetch")
|
||||
public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
|
||||
return this.taskStoreService.get(id);
|
||||
}
|
||||
public String fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
|
||||
OkHttpClient client = new OkHttpClient();
|
||||
// 创建一个Request对象来配置你的请求
|
||||
Request request = new Request.Builder()
|
||||
.header("mj-api-secret", apiKey) // 设置Authorization header
|
||||
.url(apiHost+"mj/task/" + id + "/fetch")
|
||||
.build();
|
||||
try (Response response = client.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
|
||||
if (response.body() != null) {
|
||||
return response.body().string();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("任务:{}查询失败:{}",id,e.getMessage());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@ApiOperation(value = "查询任务队列")
|
||||
@GetMapping("/queue")
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.xmzs.midjourney.domain;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* @author WangLe
|
||||
*/
|
||||
@Data
|
||||
@ApiModel("Discord账号")
|
||||
public class InsightFace implements Serializable {
|
||||
/**本人头像json*/
|
||||
@ApiModelProperty("本人头像json")
|
||||
private String sourceBase64;
|
||||
|
||||
/**明星头像json*/
|
||||
@ApiModelProperty("明星头像json")
|
||||
private String targetBase64;
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import com.xmzs.midjourney.enums.TaskAction;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
|
||||
@Data
|
||||
@ApiModel("变化任务提交参数")
|
||||
public class SubmitActionDTO {
|
||||
|
||||
private String customId;
|
||||
|
||||
private String taskId;
|
||||
|
||||
private String state;
|
||||
|
||||
private String notifyHook;
|
||||
}
|
||||
@@ -1,32 +1,152 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.common.chat.constant.OpenAIConst;
|
||||
import com.xmzs.common.core.exception.ServiceException;
|
||||
import com.xmzs.system.service.IChatService;
|
||||
import com.xmzs.system.service.ISseService;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.StreamUtils;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.servlet.HandlerInterceptor;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Enumeration;
|
||||
import java.util.Objects;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class ApiAuthorizeInterceptor implements HandlerInterceptor {
|
||||
private final ProxyProperties properties;
|
||||
private static final Logger log = LoggerFactory.getLogger(ApiAuthorizeInterceptor.class);
|
||||
private static final String API_SECRET_HEADER = "mj-api-secret";
|
||||
|
||||
@Override
|
||||
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
|
||||
if (CharSequenceUtil.isBlank(this.properties.getApiSecret())) {
|
||||
return true;
|
||||
}
|
||||
String apiSecret = request.getHeader(Constants.API_SECRET_HEADER_NAME);
|
||||
boolean authorized = CharSequenceUtil.equals(apiSecret, this.properties.getApiSecret());
|
||||
if (!authorized) {
|
||||
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
|
||||
}
|
||||
return authorized;
|
||||
}
|
||||
@Value("${chat.apiKey}")
|
||||
private String API_SECRET_VALUE;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
|
||||
@Autowired
|
||||
private RestTemplate restTemplate;
|
||||
|
||||
@Autowired
|
||||
private IChatService chatService;
|
||||
|
||||
@Autowired
|
||||
private ISseService sseService;
|
||||
|
||||
@Override
|
||||
public boolean preHandle(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull Object handler) {
|
||||
// 判断是否是MidJourney的请求
|
||||
if (isMidJourneyRequest(request)) {
|
||||
try {
|
||||
// 处理请求,例如费用扣除和任务查询
|
||||
processRequest(request);
|
||||
// 转发请求到目标服务器
|
||||
forwardRequest(request, response);
|
||||
} catch (Exception e) {
|
||||
// 记录错误日志,包括异常堆栈信息
|
||||
log.error("转发请求时发生错误", e);
|
||||
// 设置HTTP状态码和响应信息
|
||||
response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
|
||||
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
|
||||
try {
|
||||
// 向客户端返回错误信息
|
||||
response.getWriter().write("绘图失败!" + e.getMessage());
|
||||
} catch (Exception ex) {
|
||||
log.error("设置错误响应时发生错误", ex);
|
||||
}
|
||||
// 中断请求处理流程
|
||||
return false;
|
||||
}
|
||||
// 中断正常的请求处理流程,因为请求已被转发
|
||||
return false;
|
||||
}
|
||||
// 如果不是MidJourney的请求,则继续正常处理
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean isMidJourneyRequest(HttpServletRequest request) {
|
||||
String uri = request.getRequestURI();
|
||||
return uri.startsWith("/mj") &&
|
||||
!uri.matches(".*/\\d+/fetch") &&
|
||||
!uri.matches("/mj/insight-face/swap") &&
|
||||
!uri.matches("/mj/submit/action");
|
||||
}
|
||||
|
||||
private void processRequest(HttpServletRequest request) {
|
||||
// 处理付费用户的请求,包括费用扣除和任务查询
|
||||
sseService.checkUserGrade();
|
||||
String uri = request.getRequestURI();
|
||||
if (uri.matches("/mj/submit/describe") || uri.matches("/mj/submit/shorten")) {
|
||||
chatService.mjTaskDeduct(uri.endsWith("describe") ? "图生文" : "prompt分析", OpenAIConst.MJ_COST_TYPE2);
|
||||
} else if (uri.endsWith("image-seed") || uri.endsWith("list-by-condition")) {
|
||||
chatService.mjTaskDeduct(uri.endsWith("image-seed") ? "获取种子" : "任务查询", OpenAIConst.MJ_COST_TYPE3);
|
||||
} else if (uri.matches("/mj/submit/.*")) {
|
||||
chatService.mjTaskDeduct("文生图", OpenAIConst.MJ_COST_TYPE1);
|
||||
}
|
||||
}
|
||||
|
||||
private void forwardRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
String targetUrl = buildTargetUrl(request);
|
||||
HttpEntity<String> entity = new HttpEntity<>(readRequestBody(request), copyHeaders(request));
|
||||
HttpMethod method = HttpMethod.valueOf(request.getMethod());
|
||||
ResponseEntity<byte[]> responseEntity = restTemplate.exchange(targetUrl, method, entity, byte[].class);
|
||||
copyResponseBack(response, responseEntity);
|
||||
}
|
||||
|
||||
private String buildTargetUrl(HttpServletRequest request) {
|
||||
String uri = request.getRequestURI();
|
||||
String queryString = request.getQueryString();
|
||||
log.info("Forwarding URL: {}", uri);
|
||||
return apiHost + uri + (queryString != null ? "?" + queryString : "");
|
||||
}
|
||||
|
||||
private HttpHeaders copyHeaders(HttpServletRequest request) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.set(API_SECRET_HEADER, API_SECRET_VALUE);
|
||||
Enumeration<String> headerNames = request.getHeaderNames();
|
||||
while (headerNames.hasMoreElements()) {
|
||||
String headerName = headerNames.nextElement();
|
||||
if (!headerName.equalsIgnoreCase(API_SECRET_HEADER) &&
|
||||
!headerName.equalsIgnoreCase(HttpHeaders.CONTENT_LENGTH) &&
|
||||
!headerName.equalsIgnoreCase(HttpHeaders.AUTHORIZATION)) {
|
||||
headers.set(headerName, request.getHeader(headerName));
|
||||
}
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
private String readRequestBody(HttpServletRequest request) throws Exception {
|
||||
if (request.getContentLengthLong() > 0) {
|
||||
try (InputStream inputStream = request.getInputStream()) {
|
||||
return StreamUtils.copyToString(inputStream, StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
private void copyResponseBack(HttpServletResponse response, ResponseEntity<byte[]> responseEntity) throws Exception {
|
||||
HttpHeaders responseHeaders = responseEntity.getHeaders();
|
||||
responseHeaders.forEach((key, values) -> {
|
||||
if (!key.equalsIgnoreCase(API_SECRET_HEADER)) {
|
||||
response.addHeader(key, String.join(",", values));
|
||||
}
|
||||
});
|
||||
// 设置响应内容类型为UTF-8,防止乱码
|
||||
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
|
||||
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
|
||||
HttpStatus status = HttpStatus.resolve(responseEntity.getStatusCode().value());
|
||||
response.setStatus(Objects.requireNonNullElse(status, HttpStatus.INTERNAL_SERVER_ERROR).value());
|
||||
if (responseEntity.getBody() != null) {
|
||||
StreamUtils.copy(responseEntity.getBody(), response.getOutputStream());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.ViewControllerRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
|
||||
|
||||
|
||||
|
||||
@Configuration
|
||||
public class WebMvcConfig implements WebMvcConfigurer {
|
||||
@Resource
|
||||
@@ -27,7 +25,7 @@ public class WebMvcConfig implements WebMvcConfigurer {
|
||||
public void addInterceptors(InterceptorRegistry registry) {
|
||||
if (CharSequenceUtil.isNotBlank(this.properties.getApiSecret())) {
|
||||
registry.addInterceptor(this.apiAuthorizeInterceptor)
|
||||
.addPathPatterns("/submit/**", "/task/**", "/account/**");
|
||||
.addPathPatterns("/mj/submit/**", "/mj/task/**", "/mj/account/**");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user