⭐⭐⭐ Spring Boot 项目实战 ⭐⭐⭐ Spring Cloud 项目实战
《Dubbo 实现原理与源码解析 —— 精品合集》 《Netty 实现原理与源码解析 —— 精品合集》
《Spring 实现原理与源码解析 —— 精品合集》 《MyBatis 实现原理与源码解析 —— 精品合集》
《Spring MVC 实现原理与源码解析 —— 精品合集》 《数据库实体设计合集》
《Spring Boot 实现原理与源码解析 —— 精品合集》 《Java 面试题 + Java 学习指南》

摘要: 原创出处 blog.csdn.net/u010978399/article/details/117771620 「吃素的哈士奇」欢迎转载,保留摘要,谢谢!


🙂🙂🙂关注**微信公众号:【芋道源码】**有福利:

  1. RocketMQ / MyCAT / Sharding-JDBC 所有源码分析文章列表
  2. RocketMQ / MyCAT / Sharding-JDBC 中文注释源码 GitHub 地址
  3. 您对于源码的疑问每条留言将得到认真回复。甚至不知道如何读源码也可以请教噢
  4. 新的源码解析文章实时收到通知。每周更新一篇左右
  5. 认真的源码交流微信群。

特别说明CountDownLatch

CountDownLatch是一个类springboot自带的类,可以直接用,变量AtomicBoolean 也是可以直接使用

CountDownLatch的用法

CountDownLatch典型用法:

1、某一线程在开始运行前等待n个线程执行完毕。 将CountDownLatch的计数器初始化为new CountDownLatch(n),每当一个任务线程执行完毕,就将计数器减1 countdownLatch.countDown(),当计数器的值变为0时,在CountDownLatch上await()的线程就会被唤醒。一个典型应用场景就是启动一个服务时,主线程需要等待多个组件加载完毕,之后再继续执行。

2、实现多个线程开始执行任务的最大并行性。 注意是并行性,不是并发,强调的是多个线程在某一时刻同时开始执行。类似于赛跑,将多个线程放到起点,等待发令枪响,然后同时开跑。做法是初始化一个共享的CountDownLatch(1),将其计算器初始化为1,多个线程在开始执行任务前首先countdownlatch.await(),当主线程调用countDown()时,计数器变为0,多个线程同时被唤醒。

CountDownLatch(num) 简单说明

new 一个 CountDownLatch(num) 对象

建立对象的时候 num 代表的是需要等待 num 个线程

// 建立对象的时候 num 代表的是需要等待 num 个线程
//主线程
CountDownLatch mainThreadLatch = new CountDownLatch(num);
//子线程
CountDownLatch rollBackLatch = new CountDownLatch(1);

主线程:mainThreadLatch.await() 和mainThreadLatch.countDown()

新建对象

CountDownLatch mainThreadLatch = new CountDownLatch(num);

卡住主线程,让其等待子线程,代码mainThreadLatch.await(),放在主线程里

mainThreadLatch.await();

代码mainThreadLatch.countDown(),放在子线程里,每一个子线程运行一到这个代码,意味着CountDownLatch(num),里面的num-1(自动减一)

mainThreadLatch.countDown();

CountDownLatch(num)里面的num减到0,也就是CountDownLatch(0),被卡住的主线程mainThreadLatch.await(),就会往下执行

子线程:rollBackLatch.await() 和rollBackLatch.countDown()

新建对象,特别注意:子线程这个num就是1(关于只能为1的解答在后面)

CountDownLatch rollBackLatch  = new CountDownLatch(1);

卡住子线程,阻止每一个子线程的事务提交和回滚

rollBackLatch.await();

代码rollBackLatch.countDown();放在主线程里,而且是放在主线程的等待代码mainThreadLatch.await();后面。

rollBackLatch.countDown();

为什么所有的子线程会在一瞬间就被所有都释放了?

事务的回滚是怎么结合进去的?

假设总共20个子线程,那么其中一个线程报错了怎么实现所有线程回滚。

引入变量

AtomicBoolean rollbackFlag = new AtomicBoolean(false)

和字面意思是一样的:根据 rollbackFlag 的true或者false 判断子线程里面,是否回滚。

首先我们确定的一点:rollbackFlag 是所有的子线程都用着这一个判断

主线程类Entry

package org.apache.dolphinscheduler.api.utils;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.controller.WorkThread;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.springframework.web.bind.annotation.*;

import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;


@RestController
@RequestMapping("importDatabase")
public class Entry {

/**
* @param dbid 数据库的id
* @param tablename 表名
* @param sftpFileName 文件名称
* @param head 是否有头文件
* @param splitSign 分隔符
* @param type 数据库类型
*/
private static String SFTP_HOST = "192.168.1.92";
private static int SFTP_PORT = 22;
private static String SFTP_USERNAME = "root";
private static String SFTP_PASSWORD = "rootroot";
private static String SFTP_BASEPATH = "/opt/testSFTP/";
@PostMapping("/thread")
@ResponseBody
public static JSONObject importDatabase(@RequestParam("dbid") int dbid
,@RequestParam("tablename") String tablename
,@RequestParam("sftpFileName") String sftpFileName
,@RequestParam("head") String head
,@RequestParam("splitSign") String splitSign
,@RequestParam("type") DbType type
,@RequestParam("heads") String heads
,@RequestParam("scolumns") String scolumns
,@RequestParam("tcolumns") String tcolumns ) throws Exception {
JSONObject obForRetrun = new JSONObject();

try {

JSONArray jsonArray = JSONArray.parseArray(tcolumns);
JSONArray scolumnArray = JSONArray.parseArray(scolumns);
JSONArray headsArray = JSONArray.parseArray(heads);
List<Integer> listInteger = getRrightDataNum(headsArray,scolumnArray);
JSONArray bodys = SFTPUtils.getSftpContent(SFTP_HOST,SFTP_PORT,SFTP_USERNAME,SFTP_PASSWORD,SFTP_BASEPATH,sftpFileName,head,splitSign);
int total = bodys.size();
int num = 20; //一个批次的数据有多少
int count = total/num;//周期
int lastNum =total- count*num;//余数

List<Thread> list = new ArrayList<Thread>();
SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
TimeZone t = sdf.getTimeZone();
t.setRawOffset(0);
sdf.setTimeZone(t);
Long startTime=System.currentTimeMillis();


int countForCountDownLatch = 0;
if(lastNum==0){//整除
countForCountDownLatch= count;
}else{
countForCountDownLatch= count + 1;
}
//子线程
CountDownLatch rollBackLatch = new CountDownLatch(1);
//主线程
CountDownLatch mainThreadLatch = new CountDownLatch(countForCountDownLatch);

AtomicBoolean rollbackFlag = new AtomicBoolean(false);
StringBuffer message = new StringBuffer();
message.append("报错信息:");

//子线程
for(int i=0;i<count;i++) {//这里的count代表有几个线程
Thread g = new Thread(new WorkThread(i,num,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message ));
g.start();
list.add(g);
}

if(lastNum!=0){//有小数的情况下
Thread g = new Thread(new WorkThread(0,lastNum,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message ));
g.start();
list.add(g);
}

// for(Thread thread:list){
// System.out.println(thread.getState());
// thread.join();//是等待这个线程结束;
// }

mainThreadLatch.await();
//所有等待的子线程全部放开
rollBackLatch.countDown();

//是主线程等待子线程的终止。也就是说主线程的代码块中,如果碰到了t.join()方法,此时主线程需要等待(阻塞),等待子线程结束了(Waits for this thread to die.),才能继续执行t.join()之后的代码块。


Long endTime=System.currentTimeMillis();
System.out.println("总共用时: "+sdf.format(new Date(endTime-startTime)));

if(rollbackFlag.get()){
obForRetrun.put("code",500);
obForRetrun.put("msg",message);
}else{
obForRetrun.put("code",200);
obForRetrun.put("msg","提交成功!");
}
obForRetrun.put("data",null);
}catch (InterruptedException e){
e.printStackTrace();
obForRetrun.put("code",500);
obForRetrun.put("msg",e.getMessage());
obForRetrun.put("data",null);
}
return obForRetrun;

}

/**
* 文件里第几列被作为导出列
* @param headsArray
* @param scolumnArray
* @return
*/
public static List<Integer> getRrightDataNum(JSONArray headsArray, JSONArray scolumnArray){
List<Integer> list = new ArrayList<Integer>();
String arrayA [] = new String[headsArray.size()];
for(int i=0;i<headsArray.size();i++){
JSONObject ob = (JSONObject)headsArray.get(i);
arrayA[i] =String.valueOf(ob.get("title"));
}

String arrayB [] = new String[scolumnArray.size()];
for(int i=0;i<scolumnArray.size();i++){
JSONObject ob = (JSONObject)scolumnArray.get(i);
arrayB[i] =String.valueOf(ob.get("columnName"));
}

for(int i =0;i<arrayA.length;i++){
for(int j=0;j<arrayB.length;j++){
if(arrayA[i].equals(arrayB[j])){
list.add(i);
break;
}
}
}

return list;
}
}

子线程类WorkThread

package org.apache.dolphinscheduler.api.controller;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.service.DataSourceService;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.apache.dolphinscheduler.dao.entity.DataSource;
import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.springframework.transaction.PlatformTransactionManager;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;


/**
* 多线程
*/
public class WorkThread implements Runnable{ //建立线程的两种方法 1 实现Runnable 接口 2 继承 Thread 类

private DataSourceService dataSourceService;

private DataSourceMapper dataSourceMapper;

private Integer begin;
private Integer end;
private String tableName;
private JSONArray columnArray;
private Integer dbid;
private DbType type;
private JSONArray bodys;
private List<Integer> listInteger;
private PlatformTransactionManager transactionManager;
private CountDownLatch mainThreadLatch;
private CountDownLatch rollBackLatch;
private AtomicBoolean rollbackFlag;
private StringBuffer message;



/**
* @param i
* @param num
* @param tableFrom
* @param columnArrayFrom
* @param dbidFrom
* @param typeFrom
*/
public WorkThread(int i, int num, String tableFrom, JSONArray columnArrayFrom, int dbidFrom
, DbType typeFrom, JSONArray bodysFrom, List<Integer> listIntegerFrom
,CountDownLatch mainThreadLatch,CountDownLatch rollBackLatch,AtomicBoolean rollbackFlag
,StringBuffer messageFrom) {
begin=i*num;
end=begin+num;
tableName = tableFrom;
columnArray = columnArrayFrom;
dbid = dbidFrom;
type = typeFrom;
bodys = bodysFrom;
listInteger = listIntegerFrom;
this.dataSourceMapper = SpringApplicationContext.getBean(DataSourceMapper.class);
this.dataSourceService = SpringApplicationContext.getBean(DataSourceService.class);
this.transactionManager = SpringApplicationContext.getBean(PlatformTransactionManager.class);
this.mainThreadLatch = mainThreadLatch;
this.rollBackLatch = rollBackLatch;
this.rollbackFlag = rollbackFlag;
this.message = messageFrom;
}

public void run() {

DataSource dataSource = dataSourceMapper.queryDataSourceByID(dbid);
String cp = dataSource.getConnectionParams();
Connection con=null;
con = dataSourceService.getConnection(type,cp);
if(con!=null)
{
SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
TimeZone t = sdf.getTimeZone();
t.setRawOffset(0);
sdf.setTimeZone(t);
Long startTime = System.currentTimeMillis();
try {
con.setAutoCommit(false);

//---------------------------- 获取字段和类型
String columnString = null;//活动的字段
int intForType = 0;
String type[] = new String[columnArray.size()];//类型集合
for(int i=0;i<columnArray.size();i++){
JSONObject ob = (JSONObject)columnArray.get(i);
if(columnString==null){
columnString = String.valueOf(ob.get("name"));
}else{
columnString = columnString + ","+ String.valueOf(ob.get("name"));
}
type[intForType] = String.valueOf(ob.get("type"));
intForType = intForType + 1;
}
intForType = 0;

//这一步是为了形成 insert into "+tableName+"(id,name,age) values (?,?,?);
String dataString = null;
for(int i=0;i<columnArray.size();i++){
if(dataString==null){
dataString = "?";
}else{
dataString = dataString +","+"?";
}
}

//--------------------------------

StringBuffer sql = new StringBuffer();
sql = sql.append("insert into "+tableName+"("+columnString+") values ("+dataString+")") ;
PreparedStatement pst= (PreparedStatement)con.prepareStatement(sql.toString());
for(int i=begin;i<end;i++) {
JSONObject ob = (JSONObject)bodys.get(i);
if(ob!=null){
String [] array = ob.get(i).toString().split("\\,");
String [] arrayFinal = getFinalData(listInteger,array);
for(int j=0;j<type.length;j++){
String typeString = type[j].toLowerCase();
int z = j+1;
if("string".equals(typeString)||"varchar".equals(typeString)){
pst.setString(z,arrayFinal[j]);//这里的第一个参数 是指 替换第几个?
}else if("int".equals(typeString)||"bigint".equals(typeString)){
pst.setInt(z,Integer.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个?
}else if("long".equals(typeString)){
pst.setLong(z,Long.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个?
}else if("double".equals(typeString)){
pst.setDouble(z,Double.parseDouble(arrayFinal[j]));
}else if("date".equals(typeString)||"datetime".equals(typeString)){
pst.setDate(z, setDateback(arrayFinal[j]));
}else if("Timestamp".equals(typeString)){
pst.setTimestamp(z, setTimestampback(arrayFinal[j]));
}
}
}
pst.addBatch();
}
pst.executeBatch();

mainThreadLatch.countDown();
rollBackLatch.await();

if(rollbackFlag.get()){
con.rollback();//表示回滚事务;
}else{
con.commit();//事务提交
}
con.close();
} catch (Exception e) {
System.out.println(e.getMessage());
message = message.append(e.getMessage());
rollbackFlag.set(true);
mainThreadLatch.countDown();
try {
con.close();
} catch (SQLException throwables) {
throwables.printStackTrace();
}
}
Long endTime = System.currentTimeMillis();
System.out.println(Thread.currentThread().getName()+":startTime= "+sdf.format(new Date(startTime))+",endTime= "+sdf.format(new Date(endTime))
+" 用时:"+sdf.format(new Date(endTime - startTime)));

}
}


public java.sql.Date setDateback(String dateString) throws ParseException {
SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" );
java.util.Date date = sdf.parse( "2015-5-6 10:30:00" );
long lg = date.getTime();// 日期 转 时间戳
return new java.sql.Date( lg );
}

public java.sql.Timestamp setTimestampback(String dateString) throws ParseException {
SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" );
java.util.Date date = sdf.parse( "2015-5-6 10:30:00" );
long lg = date.getTime();// 日期 转 时间戳
return new java.sql.Timestamp( lg );
}

public String [] getFinalData(List<Integer> listInteger,String[] array){
String [] arrayFinal = new String [listInteger.size()];
for(int i=0;i<listInteger.size();i++){
int a = listInteger.get(i);
arrayFinal[i] = array[a];
}
return arrayFinal;
}
}

代码实际运用踩坑!!!!

还记得这里有个一批次处理多少数据么,我这边设置了20,实际到运用中的时候客户给了个20W的数据,我批次设置为20,那就有1W个子线程!!!!

这还不是最糟糕的,最糟糕的是每个子线程都会创建一个数据库连接,数据库直接被我搞炸了

所以这里需要把:

int num = 20; //一个批次的数据有多少

改成:

int num = 20000; //一个批次的数据有多少

文章目录
  1. 1. 特别说明CountDownLatch
  2. 2. CountDownLatch的用法
  3. 3. CountDownLatch(num) 简单说明
  4. 4. 主线程:mainThreadLatch.await() 和mainThreadLatch.countDown()
  5. 5. 子线程:rollBackLatch.await() 和rollBackLatch.countDown()
  6. 6. 为什么所有的子线程会在一瞬间就被所有都释放了?
  7. 7. 事务的回滚是怎么结合进去的?
  8. 8. 主线程类Entry
  9. 9. 子线程类WorkThread
  10. 10. 代码实际运用踩坑!!!!