Summary: originally published at blog.csdn.net/u010978399/article/details/117771620 by 「吃素的哈士奇」. Reposting is welcome as long as this summary is kept. Thanks!
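A note on CountDownLatch
CountDownLatch is not a Spring Boot class. It ships with the JDK in `java.util.concurrent`, so it can be used directly with no extra dependency. The same is true of the AtomicBoolean class (`java.util.concurrent.atomic`).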
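How CountDownLatch is used
Typical uses of CountDownLatch:
1. Make one thread wait for n other threads to finish before it runs. Initialize the counter with `new CountDownLatch(n)`. Each time a task thread finishes, it decrements the counter by 1 with `countDownLatch.countDown()`. When the counter reaches 0, the thread blocked in `await()` on that latch is woken up. A typical scenario is service startup, where the main thread has to wait for several components to finish loading before it continues.
2. Give several threads maximum parallelism when they start a task. Note that this is about parallelism, not concurrency: the point is that the threads all begin at the same moment. It is like a race: put the threads on the starting line, wait for the starting gun, then everyone runs at once. Create one shared `CountDownLatch(1)`, i.e. with its counter initialized to 1. Every thread calls `countDownLatch.await()` before starting its task; when the main thread calls `countDown()`, the counter drops to 0 and all the waiting threads are woken up together.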
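Both patterns above can be sketched in one small program. This is a minimal illustration rather than code from the article: the class `LatchPatterns` and the method `runRace` are hypothetical names, and each worker only records its id instead of doing real work. A start gate with count 1 implements use 2, and a latch with count n implements use 1.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;

public class LatchPatterns {

    /**
     * Puts n workers behind a shared start gate (use 2), fires the gate once,
     * then blocks until all n workers have finished (use 1).
     * Returns the ids of the workers in the order they completed.
     */
    public static List<Integer> runRace(int n) {
        CountDownLatch startGate = new CountDownLatch(1); // the "starting gun"
        CountDownLatch done = new CountDownLatch(n);      // one count per worker
        List<Integer> finished = Collections.synchronizedList(new ArrayList<>());

        for (int i = 0; i < n; i++) {
            final int id = i;
            new Thread(() -> {
                try {
                    startGate.await();    // every worker waits on the start line
                    finished.add(id);     // stand-in for the real task
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    done.countDown();     // report completion even if the task failed
                }
            }).start();
        }

        startGate.countDown();            // 1 -> 0: all waiting workers are released together
        try {
            done.await();                 // the main thread waits until the count goes n -> 0
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for workers", e);
        }
        return new ArrayList<>(finished);
    }

    public static void main(String[] args) {
        System.out.println(runRace(5).size()); // prints 5
    }
}
```

Calling `countDown()` in a `finally` block means a worker that fails still releases the waiting thread.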
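CountDownLatch(num) in brief
Create a `CountDownLatch(num)` object. `num` is the number of `countDown()` calls the latch waits for. In this article each child thread calls `countDown()` exactly once, so `num` is also the number of child threads to wait for.
```java
CountDownLatch mainThreadLatch = new CountDownLatch(num);
CountDownLatch rollBackLatch = new CountDownLatch(1);
```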
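Main thread: mainThreadLatch.await() and mainThreadLatch.countDown()
Create the object:
```java
CountDownLatch mainThreadLatch = new CountDownLatch(num);
```
`mainThreadLatch.await()` goes in the main thread. It blocks the main thread so that it waits for the child threads.
`mainThreadLatch.countDown()` goes in the child threads. Each time a child thread reaches this line, the count held by `CountDownLatch(num)` drops by 1 automatically:
```java
mainThreadLatch.countDown();
```
Once that count reaches 0 (as if the latch were now `CountDownLatch(0)`), the main thread's blocked `mainThreadLatch.await()` returns and the main thread carries on.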
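A few properties of the count matter for what follows. They are shown here in a hypothetical `LatchSemantics` helper that is not part of the article: the count never goes below 0, `await()` returns immediately once it is 0, and the JDK's `await(long, TimeUnit)` gives up after the timeout and returns `false` if the count is still above 0. With the plain `await()`, a child thread that never reaches `countDown()` leaves the main thread blocked forever.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class LatchSemantics {

    /** Count left after calling countDown() `calls` times on a latch created with `initial`. */
    public static long countAfter(int initial, int calls) {
        CountDownLatch latch = new CountDownLatch(initial);
        for (int i = 0; i < calls; i++) {
            latch.countDown();            // no effect once the count is already 0
        }
        return latch.getCount();
    }

    /**
     * Waits up to timeoutMs on a latch of `initial` that received only `calls` countDown()s.
     * Returns true if the count reached 0, false if the wait timed out.
     */
    public static boolean awaitWithTimeout(int initial, int calls, long timeoutMs) {
        CountDownLatch latch = new CountDownLatch(initial);
        for (int i = 0; i < calls; i++) {
            latch.countDown();
        }
        try {
            return latch.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(countAfter(3, 5));            // 0: the count never goes negative
        System.out.println(awaitWithTimeout(3, 3, 100)); // true: count is already 0
        System.out.println(awaitWithTimeout(3, 2, 100)); // false: one countDown() is missing
    }
}
```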
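Child threads: rollBackLatch.await() and rollBackLatch.countDown()
Create the object. Note in particular that this latch's count is exactly 1 (why it must be 1 is answered below):
```java
CountDownLatch rollBackLatch = new CountDownLatch(1);
```
`rollBackLatch.await()` goes in each child thread, after the child has executed its batch and called `mainThreadLatch.countDown()`. It blocks the child so that its transaction is neither committed nor rolled back yet.
`rollBackLatch.countDown()` goes in the main thread, and it must come after the main thread's `mainThreadLatch.await()`. By that point every child has reported, so the outcome of every child is already known when they are released:
```java
rollBackLatch.countDown();
```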
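Why are all the child threads released at the same instant?
Because a CountDownLatch wakes every thread blocked in `await()` as soon as its count reaches 0. `rollBackLatch` starts at 1 and only the main thread calls `countDown()` on it, once, so that single call takes the count to 0 and releases all the children together. This is the "starting gun" pattern from use 2 above, and it is why the count has to be 1: with a larger count, the main thread's single `countDown()` would leave the children blocked forever.
How is transaction rollback tied in?
Suppose there are 20 child threads in total and one of them fails: how do we make all of them roll back?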
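Introduce a variable:
```java
AtomicBoolean rollbackFlag = new AtomicBoolean(false);
```
It means exactly what its name says: each child thread decides whether to roll back based on whether `rollbackFlag` is true or false.
The key point is that all child threads share this one `rollbackFlag`. Any child that fails sets it to true; after `rollBackLatch` releases them, every child reads it and rolls back if it is true, otherwise commits.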
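The whole handshake can be tried without a database. In this minimal sketch the transactions are simulated: `AllOrNothingDemo`, `run` and the `shouldFail` array are hypothetical, and each worker only records whether it would commit or roll back. It follows the same order as the article: report with `mainThreadLatch.countDown()`, wait on `rollBackLatch`, then decide from `rollbackFlag`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

public class AllOrNothingDemo {

    /**
     * Starts one worker per element of shouldFail. Each worker "does its inserts"
     * (simulated: it fails only if told to), reports to the main thread, then waits
     * for the main thread's verdict. Returns every worker's final action.
     */
    public static List<String> run(boolean[] shouldFail) {
        int n = shouldFail.length;
        CountDownLatch mainThreadLatch = new CountDownLatch(n); // main waits for n reports
        CountDownLatch rollBackLatch = new CountDownLatch(1);   // children wait for one verdict
        AtomicBoolean rollbackFlag = new AtomicBoolean(false);  // shared by every child
        List<String> outcomes = Collections.synchronizedList(new ArrayList<>());
        List<Thread> threads = new ArrayList<>();

        for (int i = 0; i < n; i++) {
            final boolean fail = shouldFail[i];
            Thread t = new Thread(() -> {
                // Phase 1: work inside an open "transaction" (simulated).
                if (fail) {
                    rollbackFlag.set(true);      // one failure dooms every child
                }
                mainThreadLatch.countDown();     // report exactly once
                // Phase 2: keep the transaction open until the main thread decides.
                try {
                    rollBackLatch.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    outcomes.add("rollback");    // never commit without a verdict
                    return;
                }
                outcomes.add(rollbackFlag.get() ? "rollback" : "commit");
            });
            threads.add(t);
            t.start();
        }

        try {
            mainThreadLatch.await();             // every child has reported: the flag is final
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            rollbackFlag.set(true);              // stopped waiting early: everyone rolls back
        } finally {
            rollBackLatch.countDown();           // release all children at once
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return new ArrayList<>(outcomes);
    }

    public static void main(String[] args) {
        System.out.println(run(new boolean[] {false, false, false})); // [commit, commit, commit]
        System.out.println(run(new boolean[] {false, true, false}));  // [rollback, rollback, rollback]
    }
}
```

Because each child reads `rollbackFlag` only after `rollBackLatch` is released, and the main thread releases it only after every child has reported, one failure is visible to all children before any of them commits. This coordinates the decision, not the commits themselves: if one real commit fails after others have succeeded, nothing undoes the successful ones.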
Main-thread class Entry
```java
package org.apache.dolphinscheduler.api.utils;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.controller.WorkThread;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.springframework.web.bind.annotation.*;

import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

@RestController
@RequestMapping("importDatabase")
public class Entry {
    private static String SFTP_HOST = "192.168.1.92";
    private static int SFTP_PORT = 22;
    private static String SFTP_USERNAME = "root";
    private static String SFTP_PASSWORD = "rootroot";
    private static String SFTP_BASEPATH = "/opt/testSFTP/";

    @PostMapping("/thread")
    @ResponseBody
    public static JSONObject importDatabase(@RequestParam("dbid") int dbid,
                                            @RequestParam("tablename") String tablename,
                                            @RequestParam("sftpFileName") String sftpFileName,
                                            @RequestParam("head") String head,
                                            @RequestParam("splitSign") String splitSign,
                                            @RequestParam("type") DbType type,
                                            @RequestParam("heads") String heads,
                                            @RequestParam("scolumns") String scolumns,
                                            @RequestParam("tcolumns") String tcolumns) throws Exception {
        JSONObject obForRetrun = new JSONObject();
        try {
            JSONArray jsonArray = JSONArray.parseArray(tcolumns);
            JSONArray scolumnArray = JSONArray.parseArray(scolumns);
            JSONArray headsArray = JSONArray.parseArray(heads);
            List<Integer> listInteger = getRrightDataNum(headsArray, scolumnArray);
            JSONArray bodys = SFTPUtils.getSftpContent(SFTP_HOST, SFTP_PORT, SFTP_USERNAME, SFTP_PASSWORD,
                    SFTP_BASEPATH, sftpFileName, head, splitSign);
            int total = bodys.size();
            int num = 20;                       // rows per child thread
            int count = total / num;            // number of full chunks
            int lastNum = total - count * num;  // rows left for a final, shorter chunk
            List<Thread> list = new ArrayList<Thread>();
            SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
            sdf.setTimeZone(TimeZone.getTimeZone("UTC")); // format elapsed millis as a duration
            long startTime = System.currentTimeMillis();

            // one latch count per child thread that is about to be started
            int countForCountDownLatch = (lastNum == 0) ? count : count + 1;
            CountDownLatch rollBackLatch = new CountDownLatch(1);
            CountDownLatch mainThreadLatch = new CountDownLatch(countForCountDownLatch);
            AtomicBoolean rollbackFlag = new AtomicBoolean(false);
            StringBuffer message = new StringBuffer();
            message.append("Error messages: ");

            for (int i = 0; i < count; i++) {
                Thread g = new Thread(new WorkThread(i, num, tablename, jsonArray, dbid, type, bodys,
                        listInteger, mainThreadLatch, rollBackLatch, rollbackFlag, message));
                g.start();
                list.add(g);
            }
            if (lastNum != 0) {
                // the remainder chunk covers rows [count * num, total); WorkThread clamps end to bodys.size()
                Thread g = new Thread(new WorkThread(count, num, tablename, jsonArray, dbid, type, bodys,
                        listInteger, mainThreadLatch, rollBackLatch, rollbackFlag, message));
                g.start();
                list.add(g);
            }

            try {
                mainThreadLatch.await();        // wait until every child has run its batch or failed
            } catch (InterruptedException e) {
                rollbackFlag.set(true);         // stopped waiting early: no child may commit
                throw e;
            } finally {
                rollBackLatch.countDown();      // always release the children so none stays blocked
            }

            long endTime = System.currentTimeMillis();
            System.out.println("Total time: " + sdf.format(new Date(endTime - startTime)));
            if (rollbackFlag.get()) {
                obForRetrun.put("code", 500);
                obForRetrun.put("msg", message.toString());
            } else {
                obForRetrun.put("code", 200);
                obForRetrun.put("msg", "Committed successfully!");
            }
            obForRetrun.put("data", null);
        } catch (InterruptedException e) {
            e.printStackTrace();
            obForRetrun.put("code", 500);
            obForRetrun.put("msg", e.getMessage());
            obForRetrun.put("data", null);
        }
        return obForRetrun;
    }

    // indexes of the header columns whose title matches one of the selected source columns
    public static List<Integer> getRrightDataNum(JSONArray headsArray, JSONArray scolumnArray) {
        List<Integer> list = new ArrayList<Integer>();
        String[] arrayA = new String[headsArray.size()];
        for (int i = 0; i < headsArray.size(); i++) {
            JSONObject ob = (JSONObject) headsArray.get(i);
            arrayA[i] = String.valueOf(ob.get("title"));
        }
        String[] arrayB = new String[scolumnArray.size()];
        for (int i = 0; i < scolumnArray.size(); i++) {
            JSONObject ob = (JSONObject) scolumnArray.get(i);
            arrayB[i] = String.valueOf(ob.get("columnName"));
        }
        for (int i = 0; i < arrayA.length; i++) {
            for (int j = 0; j < arrayB.length; j++) {
                if (arrayA[i].equals(arrayB[j])) {
                    list.add(i);
                    break;
                }
            }
        }
        return list;
    }
}
```
Child-thread class WorkThread
```java
package org.apache.dolphinscheduler.api.controller;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.service.DataSourceService;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.apache.dolphinscheduler.dao.entity.DataSource;
import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.springframework.transaction.PlatformTransactionManager;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

public class WorkThread implements Runnable {
    private DataSourceService dataSourceService;
    private DataSourceMapper dataSourceMapper;
    private Integer begin;
    private Integer end;
    private String tableName;
    private JSONArray columnArray;
    private Integer dbid;
    private DbType type;
    private JSONArray bodys;
    private List<Integer> listInteger;
    private PlatformTransactionManager transactionManager;
    private CountDownLatch mainThreadLatch;
    private CountDownLatch rollBackLatch;
    private AtomicBoolean rollbackFlag;
    private StringBuffer message;

    public WorkThread(int i, int num, String tableFrom, JSONArray columnArrayFrom, int dbidFrom,
                      DbType typeFrom, JSONArray bodysFrom, List<Integer> listIntegerFrom,
                      CountDownLatch mainThreadLatch, CountDownLatch rollBackLatch,
                      AtomicBoolean rollbackFlag, StringBuffer messageFrom) {
        begin = i * num;
        end = Math.min(begin + num, bodysFrom.size()); // the last chunk may be shorter than num
        tableName = tableFrom;
        columnArray = columnArrayFrom;
        dbid = dbidFrom;
        type = typeFrom;
        bodys = bodysFrom;
        listInteger = listIntegerFrom;
        this.dataSourceMapper = SpringApplicationContext.getBean(DataSourceMapper.class);
        this.dataSourceService = SpringApplicationContext.getBean(DataSourceService.class);
        this.transactionManager = SpringApplicationContext.getBean(PlatformTransactionManager.class);
        this.mainThreadLatch = mainThreadLatch;
        this.rollBackLatch = rollBackLatch;
        this.rollbackFlag = rollbackFlag;
        this.message = messageFrom;
    }

    @Override
    public void run() {
        SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
        sdf.setTimeZone(TimeZone.getTimeZone("UTC"));
        long startTime = System.currentTimeMillis();
        Connection con = null;
        boolean reported = false; // whether this thread has already counted down mainThreadLatch
        try {
            DataSource dataSource = dataSourceMapper.queryDataSourceByID(dbid);
            con = dataSourceService.getConnection(type, dataSource.getConnectionParams());
            if (con == null) {
                throw new SQLException("No connection for datasource " + dbid);
            }
            con.setAutoCommit(false);

            // build "insert into t(c1,c2,...) values (?,?,...)" and remember each column's type
            String[] columnTypes = new String[columnArray.size()];
            StringBuilder columns = new StringBuilder();
            StringBuilder placeholders = new StringBuilder();
            for (int i = 0; i < columnArray.size(); i++) {
                JSONObject ob = (JSONObject) columnArray.get(i);
                if (i > 0) {
                    columns.append(',');
                    placeholders.append(',');
                }
                columns.append(ob.get("name"));
                placeholders.append('?');
                columnTypes[i] = String.valueOf(ob.get("type"));
            }
            String sql = "insert into " + tableName + "(" + columns + ") values (" + placeholders + ")";

            try (PreparedStatement pst = con.prepareStatement(sql)) {
                for (int i = begin; i < end; i++) {
                    JSONObject ob = (JSONObject) bodys.get(i);
                    if (ob == null) {
                        continue; // skip empty rows rather than re-adding the previous row's parameters
                    }
                    String[] array = ob.get(i).toString().split(",");
                    String[] arrayFinal = getFinalData(listInteger, array);
                    for (int j = 0; j < columnTypes.length; j++) {
                        String typeString = columnTypes[j].toLowerCase();
                        int z = j + 1; // JDBC parameter indexes start at 1
                        if ("string".equals(typeString) || "varchar".equals(typeString)) {
                            pst.setString(z, arrayFinal[j]);
                        } else if ("int".equals(typeString)) {
                            pst.setInt(z, Integer.parseInt(arrayFinal[j]));
                        } else if ("bigint".equals(typeString) || "long".equals(typeString)) {
                            pst.setLong(z, Long.parseLong(arrayFinal[j])); // bigint may exceed int range
                        } else if ("double".equals(typeString)) {
                            pst.setDouble(z, Double.parseDouble(arrayFinal[j]));
                        } else if ("date".equals(typeString) || "datetime".equals(typeString)) {
                            pst.setDate(z, setDateback(arrayFinal[j]));
                        } else if ("timestamp".equals(typeString)) { // compared in lower case
                            pst.setTimestamp(z, setTimestampback(arrayFinal[j]));
                        }
                    }
                    pst.addBatch();
                }
                pst.executeBatch();
            }

            mainThreadLatch.countDown(); // phase 1 done: tell the main thread this batch succeeded
            reported = true;
            rollBackLatch.await();       // hold the transaction open until the main thread decides
            if (rollbackFlag.get()) {
                con.rollback();
            } else {
                con.commit();            // a failure here cannot undo other threads' commits
            }
        } catch (Exception e) {
            System.out.println(e.getMessage());
            message.append(e.getMessage());
            rollbackFlag.set(true);      // make every other child roll back too
            if (con != null) {
                try {
                    con.rollback();      // roll back explicitly; close() alone is driver-defined
                } catch (SQLException ex) {
                    ex.printStackTrace();
                }
            }
        } finally {
            if (!reported) {
                mainThreadLatch.countDown(); // count down exactly once, even on failure
            }
            if (con != null) {
                try {
                    con.close();
                } catch (SQLException ex) {
                    ex.printStackTrace();
                }
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.println(Thread.currentThread().getName() + ":startTime= " + sdf.format(new Date(startTime))
                + ",endTime= " + sdf.format(new Date(endTime))
                + " elapsed: " + sdf.format(new Date(endTime - startTime)));
    }

    public java.sql.Date setDateback(String dateString) throws ParseException {
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        java.util.Date date = sdf.parse(dateString); // parse the value read from the file
        return new java.sql.Date(date.getTime());
    }

    public java.sql.Timestamp setTimestampback(String dateString) throws ParseException {
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        java.util.Date date = sdf.parse(dateString);
        return new java.sql.Timestamp(date.getTime());
    }

    // keep only the file columns selected by listInteger, in that order
    public String[] getFinalData(List<Integer> listInteger, String[] array) {
        String[] arrayFinal = new String[listInteger.size()];
        for (int i = 0; i < listInteger.size(); i++) {
            int a = listInteger.get(i);
            arrayFinal[i] = array[a];
        }
        return arrayFinal;
    }
}
```
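Splitting the rows into `num`-sized chunks plus a shorter remainder is the easiest part to get wrong, and the number of chunks must equal the count `mainThreadLatch` is created with (`count`, or `count + 1` when there is a remainder). Below is a standalone sketch of that arithmetic using half-open `[begin, end)` ranges like the `begin`/`end` fields of `WorkThread`; `ChunkRanges` and `split` are hypothetical names, not part of the article.

```java
import java.util.ArrayList;
import java.util.List;

public class ChunkRanges {

    /** Splits rows [0, total) into half-open [begin, end) ranges of at most `size` rows each. */
    public static List<int[]> split(int total, int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive");
        }
        List<int[]> ranges = new ArrayList<>();
        for (int begin = 0; begin < total; begin += size) {
            // the last range is clamped to total, so it holds the remainder rows
            ranges.add(new int[] {begin, Math.min(begin + size, total)});
        }
        return ranges;
    }

    public static void main(String[] args) {
        for (int[] r : split(45, 20)) {
            System.out.println(r[0] + ".." + r[1]); // 0..20, then 20..40, then 40..45
        }
    }
}
```

`split(total, num).size()` is then the value to pass to `new CountDownLatch(...)`; if it is larger than the number of threads actually started, the main thread's `await()` never returns.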
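Pitfalls when putting this code into real use!!!!
Remember the setting for how many rows each batch handles? I set it to 20. In real use the customer supplied 200,000 rows, and with a batch size of 20 that meant 10,000 child threads!!!!
That still wasn't the worst part. The worst part was that every child thread opens its own database connection, and I brought the database straight down.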
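One way to bound both the thread count and the connection count, offered as an independent sketch rather than the author's exact change, is to fix the number of worker threads and derive the batch size from it instead of fixing the batch size. The class `BatchSizing` and the limit of 10 threads are assumptions; the limit should stay below the database's connection limit.

```java
public class BatchSizing {

    /** Rows per worker so that at most maxThreads workers (and connections) are needed. */
    public static int rowsPerThread(int total, int maxThreads) {
        if (maxThreads <= 0) {
            throw new IllegalArgumentException("maxThreads must be positive");
        }
        if (total <= 0) {
            return 0;
        }
        return (total + maxThreads - 1) / maxThreads; // ceiling division
    }

    /** Workers actually started for that batch size (never more than maxThreads). */
    public static int threadCount(int total, int maxThreads) {
        int size = rowsPerThread(total, maxThreads);
        return size == 0 ? 0 : (total + size - 1) / size;
    }

    public static void main(String[] args) {
        int total = 200_000;
        System.out.println(total / 20);               // 10000 workers with a fixed batch of 20
        System.out.println(rowsPerThread(total, 10)); // 20000 rows per worker
        System.out.println(threadCount(total, 10));   // 10 workers, so 10 connections
    }
}
```

Used in `Entry`, the result would take the place of the fixed `num = 20`, with an empty file handled separately since `num` would then be 0. Swapping `new Thread(...)` for a fixed-size `ExecutorService` with fewer threads than chunks would not work with this design: each child blocks in `rollBackLatch.await()` while holding its pool thread, so queued chunks never start, `mainThreadLatch` never reaches 0, and the import deadlocks.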
So the following:
needs to be changed to: