init v1.0.0

This commit is contained in:
ageer
2024-02-27 20:52:19 +08:00
parent 1f7f97e86a
commit a079ef44e5
602 changed files with 163057 additions and 95 deletions

View File

@@ -0,0 +1,59 @@
package com.xmzs.midjourney.util;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.thread.ThreadUtil;
import com.xmzs.midjourney.domain.DomainObject;
import lombok.experimental.UtilityClass;
import java.time.Duration;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@UtilityClass
public class AsyncLockUtils {
private static final TimedCache<String, LockObject> LOCK_MAP = CacheUtil.newTimedCache(Duration.ofDays(1).toMillis());
public static synchronized LockObject getLock(String key) {
return LOCK_MAP.get(key);
}
public static LockObject waitForLock(String key, Duration duration) throws TimeoutException {
LockObject lockObject;
synchronized (LOCK_MAP) {
if (!LOCK_MAP.containsKey(key)) {
LOCK_MAP.put(key, new LockObject(key));
}
lockObject = LOCK_MAP.get(key);
}
Future<?> future = ThreadUtil.execAsync(() -> {
try {
lockObject.sleep();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
try {
future.get(duration.toMillis(), TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (ExecutionException e) {
// do nothing
} catch (TimeoutException e) {
future.cancel(true);
throw new TimeoutException("Wait Timeout");
} finally {
LOCK_MAP.remove(lockObject.getId());
}
return lockObject;
}
public static class LockObject extends DomainObject {
public LockObject(String id) {
this.id = id;
}
}
}

View File

@@ -0,0 +1,43 @@
package com.xmzs.midjourney.util;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.exception.BannedPromptException;
import lombok.experimental.UtilityClass;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@UtilityClass
public class BannedPromptUtils {
private static final String BANNED_WORDS_FILE_PATH = "/home/spring/config/banned-words.txt";
private final List<String> BANNED_WORDS;
static {
List<String> lines;
File file = new File(BANNED_WORDS_FILE_PATH);
if (file.exists()) {
lines = FileUtil.readLines(file, StandardCharsets.UTF_8);
} else {
var resource = BannedPromptUtils.class.getResource("/banned-words.txt");
lines = FileUtil.readLines(resource, StandardCharsets.UTF_8);
}
BANNED_WORDS = lines.stream().filter(CharSequenceUtil::isNotBlank).toList();
}
public static void checkBanned(String promptEn) throws BannedPromptException {
String finalPromptEn = promptEn.toLowerCase(Locale.ENGLISH);
for (String word : BANNED_WORDS) {
Matcher matcher = Pattern.compile("\\b" + word + "\\b").matcher(finalPromptEn);
if (matcher.find()) {
int index = CharSequenceUtil.indexOfIgnoreCase(promptEn, word);
throw new BannedPromptException(promptEn.substring(index, index + word.length()));
}
}
}
}

View File

@@ -0,0 +1,9 @@
package com.xmzs.midjourney.util;
import lombok.Data;
@Data
public class ContentParseData {
protected String prompt;
protected String status;
}

View File

@@ -0,0 +1,85 @@
package com.xmzs.midjourney.util;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.enums.TaskAction;
import eu.maxschuster.dataurl.DataUrl;
import eu.maxschuster.dataurl.DataUrlSerializer;
import eu.maxschuster.dataurl.IDataUrlSerializer;
import lombok.experimental.UtilityClass;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@UtilityClass
public class ConvertUtils {
/**
* content正则匹配prompt和进度.
*/
public static final String CONTENT_REGEX = ".*?\\*\\*(.*?)\\*\\*.+<@\\d+> \\((.*?)\\)";
public static ContentParseData parseContent(String content) {
return parseContent(content, CONTENT_REGEX);
}
public static ContentParseData parseContent(String content, String regex) {
if (CharSequenceUtil.isBlank(content)) {
return null;
}
Matcher matcher = Pattern.compile(regex).matcher(content);
if (!matcher.find()) {
return null;
}
ContentParseData parseData = new ContentParseData();
parseData.setPrompt(matcher.group(1));
parseData.setStatus(matcher.group(2));
return parseData;
}
public static List<DataUrl> convertBase64Array(List<String> base64Array) throws MalformedURLException {
if (base64Array == null || base64Array.isEmpty()) {
return Collections.emptyList();
}
IDataUrlSerializer serializer = new DataUrlSerializer();
List<DataUrl> dataUrlList = new ArrayList<>();
for (String base64 : base64Array) {
DataUrl dataUrl = serializer.unserialize(base64);
dataUrlList.add(dataUrl);
}
return dataUrlList;
}
public static TaskChangeParams convertChangeParams(String content) {
List<String> split = CharSequenceUtil.split(content, " ");
if (split.size() != 2) {
return null;
}
String action = split.get(1).toLowerCase();
TaskChangeParams changeParams = new TaskChangeParams();
changeParams.setId(split.get(0));
if (action.charAt(0) == 'u') {
changeParams.setAction(TaskAction.UPSCALE);
} else if (action.charAt(0) == 'v') {
changeParams.setAction(TaskAction.VARIATION);
} else if (action.equals("r")) {
changeParams.setAction(TaskAction.REROLL);
return changeParams;
} else {
return null;
}
try {
int index = Integer.parseInt(action.substring(1, 2));
if (index < 1 || index > 4) {
return null;
}
changeParams.setIndex(index);
} catch (Exception e) {
return null;
}
return changeParams;
}
}

View File

@@ -0,0 +1,45 @@
package com.xmzs.midjourney.util;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.text.CharSequenceUtil;
import lombok.experimental.UtilityClass;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@UtilityClass
public class MimeTypeUtils {
private final Map<String, List<String>> MIME_TYPE_MAP;
static {
MIME_TYPE_MAP = new HashMap<>();
var resource = MimeTypeUtils.class.getResource("/mime.types");
var lines = FileUtil.readLines(resource, StandardCharsets.UTF_8);
for (var line : lines) {
if (CharSequenceUtil.isBlank(line)) {
continue;
}
var arr = line.split(":");
MIME_TYPE_MAP.put(arr[0], CharSequenceUtil.split(arr[1], ' '));
}
}
public static String guessFileSuffix(String mimeType) {
if (CharSequenceUtil.isBlank(mimeType)) {
return null;
}
String key = mimeType;
if (!MIME_TYPE_MAP.containsKey(key)) {
key = MIME_TYPE_MAP.keySet().stream().filter(k -> CharSequenceUtil.startWithIgnoreCase(mimeType, k))
.findFirst().orElse(null);
}
var suffixList = MIME_TYPE_MAP.get(key);
if (suffixList == null || suffixList.isEmpty()) {
return null;
}
return suffixList.iterator().next();
}
}

View File

@@ -0,0 +1,152 @@
package com.xmzs.midjourney.util;
import cn.hutool.core.exceptions.ValidateException;
import com.xmzs.midjourney.exception.SnowFlakeException;
import lombok.extern.slf4j.Slf4j;
import java.lang.management.ManagementFactory;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.util.Date;
import java.util.concurrent.ThreadLocalRandom;
@Slf4j
public class SnowFlake {
private long workerId;
private long datacenterId;
private long sequence = 0L;
private final long twepoch;
private final long sequenceMask;
private final long workerIdShift;
private final long datacenterIdShift;
private final long timestampLeftShift;
private long lastTimestamp = -1L;
private final boolean randomSequence;
private long count = 0L;
private final long timeOffset;
private final ThreadLocalRandom tlr = ThreadLocalRandom.current();
public static final SnowFlake INSTANCE = new SnowFlake();
private SnowFlake() {
this(false, 10, null, 5L, 5L, 12L);
}
private SnowFlake(boolean randomSequence, long timeOffset, Date epochDate, long workerIdBits, long datacenterIdBits, long sequenceBits) {
if (null != epochDate) {
this.twepoch = epochDate.getTime();
} else {
// 2012/12/12 23:59:59 GMT
this.twepoch = 1355327999000L;
}
long maxWorkerId = ~(-1L << workerIdBits);
long maxDatacenterId = ~(-1L << datacenterIdBits);
this.sequenceMask = ~(-1L << sequenceBits);
this.workerIdShift = sequenceBits;
this.datacenterIdShift = sequenceBits + workerIdBits;
this.timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
this.randomSequence = randomSequence;
this.timeOffset = timeOffset;
try {
this.datacenterId = getDatacenterId(maxDatacenterId);
this.workerId = getMaxWorkerId(datacenterId, maxWorkerId);
} catch (Exception e) {
log.warn("datacenterId or workerId generate error: {}, set default value", e.getMessage());
this.datacenterId = 4;
this.workerId = 1;
}
}
public synchronized String nextId() {
long currentTimestamp = timeGen();
if (currentTimestamp < this.lastTimestamp) {
long offset = this.lastTimestamp - currentTimestamp;
if (offset > this.timeOffset) {
throw new ValidateException("Clock moved backwards, refusing to generate id for [" + offset + "ms]");
}
try {
this.wait(offset << 1);
} catch (InterruptedException e) {
throw new SnowFlakeException(e);
}
currentTimestamp = timeGen();
if (currentTimestamp < this.lastTimestamp) {
throw new SnowFlakeException("Clock moved backwards, refusing to generate id for [" + offset + "ms]");
}
}
if (this.lastTimestamp == currentTimestamp) {
long tempSequence = this.sequence + 1;
if (this.randomSequence) {
this.sequence = tempSequence & this.sequenceMask;
this.count = (this.count + 1) & this.sequenceMask;
if (this.count == 0) {
currentTimestamp = this.tillNextMillis(this.lastTimestamp);
}
} else {
this.sequence = tempSequence & this.sequenceMask;
if (this.sequence == 0) {
currentTimestamp = this.tillNextMillis(lastTimestamp);
}
}
} else {
this.sequence = this.randomSequence ? this.tlr.nextLong(this.sequenceMask + 1) : 0L;
this.count = 0L;
}
this.lastTimestamp = currentTimestamp;
long id = ((currentTimestamp - this.twepoch) << this.timestampLeftShift) |
(this.datacenterId << this.datacenterIdShift) |
(this.workerId << this.workerIdShift) |
this.sequence;
return String.valueOf(id);
}
private static long getDatacenterId(long maxDatacenterId) {
long id = 0L;
try {
InetAddress ip = InetAddress.getLocalHost();
NetworkInterface network = NetworkInterface.getByInetAddress(ip);
if (network == null) {
id = 1L;
} else {
byte[] mac = network.getHardwareAddress();
if (null != mac) {
id = ((0x000000FF & (long) mac[mac.length - 1]) | (0x0000FF00 & (((long) mac[mac.length - 2]) << 8))) >> 6;
id = id % (maxDatacenterId + 1);
}
}
} catch (Exception e) {
throw new SnowFlakeException(e);
}
return id;
}
private static long getMaxWorkerId(long datacenterId, long maxWorkerId) {
StringBuilder macIpPid = new StringBuilder();
macIpPid.append(datacenterId);
try {
String name = ManagementFactory.getRuntimeMXBean().getName();
if (name != null && !name.isEmpty()) {
macIpPid.append(name.split("@")[0]);
}
String hostIp = InetAddress.getLocalHost().getHostAddress();
String ipStr = hostIp.replace("\\.", "");
macIpPid.append(ipStr);
} catch (Exception e) {
throw new SnowFlakeException(e);
}
return (macIpPid.toString().hashCode() & 0xffff) % (maxWorkerId + 1);
}
private long tillNextMillis(long lastTimestamp) {
long timestamp = timeGen();
while (timestamp <= lastTimestamp) {
timestamp = timeGen();
}
return timestamp;
}
private long timeGen() {
return System.currentTimeMillis();
}
}

View File

@@ -0,0 +1,11 @@
package com.xmzs.midjourney.util;
import com.xmzs.midjourney.enums.TaskAction;
import lombok.Data;
@Data
public class TaskChangeParams {
private String id;
private TaskAction action;
private Integer index;
}

View File

@@ -0,0 +1,10 @@
package com.xmzs.midjourney.util;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
public class UVContentParseData extends ContentParseData {
protected Integer index;
}