Collaborative Filtering Recommendation Algorithm (Implemented with the Plain Java JDK, Source Code Link Included)


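I. Project Requirements

1.    Requirement link

https://tianchi.aliyun.com/getStart/information.htm?raceId=231522

2.    Requirement details

Competition task

In real business scenarios, we often need to build a personalized recommendation model for a subset of all items. In doing so, we need to use not only the users' behavior data on that item subset but, frequently, richer user behavior data as well. Define the following symbols:
U: the set of users
I: the full set of items
P: the item subset, P ⊆ I
D: the set of user behavior data over the full item set
Our goal is then to use D to build a recommendation model for the users in U over the items in P.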

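Data description

The competition provides complete behavior data for 20,000 users and item information on the scale of millions. The data consists of two parts.

The first part is the users' mobile behavior data over the full item set (D), in the table tianchi_fresh_comp_train_user_2w, with the following fields: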

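| Field | Description | Extraction notes |
| --- | --- | --- |
| user_id | User identifier | Sampled & field anonymized |
| item_id | Item identifier | Field anonymized |
| behavior_type | Type of the user's action on the item | View, favorite, add to cart, purchase; encoded as 1, 2, 3, 4 respectively |
| user_geohash | Spatial identifier of the user's location; may be empty | Generated from latitude and longitude by a confidential algorithm |
| item_category | Item category identifier | Field anonymized |
| time | Time of the action | Accurate to the hour |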

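The second part is the item subset (P), in the table tianchi_fresh_comp_train_item_2w, with the following fields:

| Field | Description | Extraction notes |
| --- | --- | --- |
| item_id | Item identifier | Sampled & field anonymized |
| item_geohash | Spatial identifier of the item's location; may be empty | Generated from latitude and longitude by a confidential algorithm |
| item_category | Item category identifier | Field anonymized |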

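The training data contains the mobile behavior data (D) of a sampled set of users over one month (11.18 to 12.18). The scoring data is those users' purchases on the item subset (P) on the day after that month (12.19). Participants must build a recommendation model from the training data and output the predicted purchases on the item subset for the following day.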

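Submission format

After predicting the users' purchases on the item subset, participants must place the results in a data table of the specified format (a non-partitioned table). The result table must be named tianchi_mobile_recommendation_predict.csv and be encoded in UTF-8. It contains two columns, user_id and item_id (both of type string), and duplicate rows must be removed.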

 

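Evaluation metrics

The competition uses the classic precision, recall, and F1 score as evaluation metrics, computed as follows:

$$\mathrm{Precision} = \frac{|\mathrm{PredictionSet} \cap \mathrm{ReferenceSet}|}{|\mathrm{PredictionSet}|}$$

$$\mathrm{Recall} = \frac{|\mathrm{PredictionSet} \cap \mathrm{ReferenceSet}|}{|\mathrm{ReferenceSet}|}$$

$$F1 = \frac{2 \times \mathrm{Precision} \times \mathrm{Recall}}{\mathrm{Precision} + \mathrm{Recall}}$$

Here PredictionSet is the set of purchases predicted by the algorithm and ReferenceSet is the set of actual purchases. The F1 score is the sole final evaluation criterion.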
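As a rough illustration of these metrics, the sketch below computes precision, recall, and F1 over two sets of "user_id,item_id" strings. The class name F1Metric, the pair encoding, and the sample ids are assumptions made for this example and are not part of the competition tooling.

```java
import java.util.HashSet;
import java.util.Set;

/** Precision / recall / F1 over sets of "user_id,item_id" pairs (illustrative sketch). */
final class F1Metric {

    /** Returns {precision, recall, f1}; each value is 0 when its denominator is 0. */
    static double[] evaluate(Set<String> predictionSet, Set<String> referenceSet) {
        // Intersection of the predicted pairs and the actual pairs
        Set<String> hits = new HashSet<>(predictionSet);
        hits.retainAll(referenceSet);

        double precision = predictionSet.isEmpty() ? 0.0 : (double) hits.size() / predictionSet.size();
        double recall = referenceSet.isEmpty() ? 0.0 : (double) hits.size() / referenceSet.size();
        double f1 = (precision + recall) == 0.0 ? 0.0 : 2 * precision * recall / (precision + recall);
        return new double[] {precision, recall, f1};
    }

    public static void main(String[] args) {
        // Made-up pairs, for illustration only
        Set<String> predicted = new HashSet<>();
        predicted.add("u1,i1");
        predicted.add("u1,i2");
        predicted.add("u2,i3");
        predicted.add("u3,i9");

        Set<String> actual = new HashSet<>();
        actual.add("u1,i1");
        actual.add("u2,i3");
        actual.add("u2,i4");

        double[] m = evaluate(predicted, actual);
        System.out.println("precision=" + m[0] + " recall=" + m[1] + " f1=" + m[2]);
    }
}
```

Using sets makes the "remove duplicates" requirement of the submission format automatic: a pair that appears twice is counted once.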

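II. Principles and Workflow of Collaborative Filtering

1.    User-based collaborative filtering

User-based collaborative filtering looks for neighbor users whose ratings resemble the target user's, then infers the target user's preferences from the items those neighbors like. The basic idea is: use the user-item rating matrix to find the current user's nearest neighbors, use the neighbors' ratings to predict the current user's ratings for items, and recommend the N items with the highest predicted ratings. Here an "item" can be understood as a product handled by the system. The workflow is shown in Figure 1.

Figure 1. Workflow of the user-based collaborative filtering algorithm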

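The workflow of user-based collaborative filtering is:

1). Build the user-item rating matrix

$R = \{r_1, r_2, \dots, r_m\}^T$ is the $m \times n$ user rating matrix, where $r_i = \{r_{i1}, r_{i2}, \dots, r_{in}\}$ is the rating vector of user $U_i$ and $r_{ij}$ is user $U_i$'s rating of item $I_j$.

2). Compute user similarity

User-based collaborative filtering needs to find users similar to the target user. Measuring the similarity between users means comparing each user's ratings, that is, that user's row of the rating matrix, with every other user's. Each user's ratings of the items can be viewed as an n-dimensional rating vector. These vectors are used to compute the similarity sim(i,j) between the target user $U_i$ and another user $U_j$. Three measures are commonly used: cosine similarity, adjusted cosine similarity, and the Pearson correlation coefficient.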
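To make two of these measures concrete, here is a minimal sketch of cosine similarity and the Pearson correlation coefficient over sparse rating vectors stored as maps from item id to rating. The class name UserSimilarity and the choice to compute Pearson only over co-rated items are assumptions of this example; adjusted cosine is omitted for brevity.

```java
import java.util.HashMap;
import java.util.Map;

/** Similarity between two users' sparse rating vectors (itemId -> rating). Illustrative sketch. */
final class UserSimilarity {

    /** Cosine similarity; items missing from a vector are treated as rating 0. */
    static double cosine(Map<String, Double> a, Map<String, Double> b) {
        double dot = 0.0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            Double other = b.get(e.getKey());
            if (other != null) {
                dot += e.getValue() * other;
            }
        }
        double normA = norm(a);
        double normB = norm(b);
        return (normA == 0.0 || normB == 0.0) ? 0.0 : dot / (normA * normB);
    }

    /** Pearson correlation over the items both users have rated; 0 when it is undefined. */
    static double pearson(Map<String, Double> a, Map<String, Double> b) {
        double sumA = 0.0, sumB = 0.0;
        int n = 0;
        for (String item : a.keySet()) {
            if (b.containsKey(item)) {
                sumA += a.get(item);
                sumB += b.get(item);
                n++;
            }
        }
        if (n == 0) {
            return 0.0;
        }
        double meanA = sumA / n, meanB = sumB / n;
        double cov = 0.0, varA = 0.0, varB = 0.0;
        for (String item : a.keySet()) {
            if (!b.containsKey(item)) {
                continue;
            }
            double da = a.get(item) - meanA;
            double db = b.get(item) - meanB;
            cov += da * db;
            varA += da * da;
            varB += db * db;
        }
        return (varA == 0.0 || varB == 0.0) ? 0.0 : cov / Math.sqrt(varA * varB);
    }

    private static double norm(Map<String, Double> v) {
        double sumSq = 0.0;
        for (double x : v.values()) {
            sumSq += x * x;
        }
        return Math.sqrt(sumSq);
    }

    public static void main(String[] args) {
        Map<String, Double> u1 = new HashMap<>();
        u1.put("i1", 1.0);
        u1.put("i2", 2.0);
        u1.put("i3", 3.0);
        Map<String, Double> u2 = new HashMap<>();
        u2.put("i1", 3.0);
        u2.put("i2", 2.0);
        u2.put("i3", 1.0);
        System.out.println("cosine=" + cosine(u1, u2) + " pearson=" + pearson(u1, u2));
    }
}
```

With non-negative scores, cosine similarity stays in [0, 1], while Pearson can reach negative values because it subtracts each user's mean first.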

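3). Build the nearest-neighbor set

The nearest-neighbor set Neighbor(u) contains the other users who share the target user's tastes. To select neighbors, we first compute the similarity sim(u,v) between the target user u and every other user v, then pick the k users with the highest similarity. User similarity can be read as the trust value or recommendation weight between users. In general, sim(u,v) ∈ [-1, 1]. A similarity of 1 means the two users carry a large recommendation weight for each other. A similarity of -1 means their interests differ greatly, so their mutual recommendation weight is small.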
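Selecting the k nearest neighbors amounts to sorting the similarity scores in descending order and keeping the first k. The sketch below shows one way to do that; the class name NeighborSelector and the tie-break on user id are assumptions added so that the result is deterministic.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Picks the k most similar users from a map of userId -> sim(target, userId). Illustrative sketch. */
final class NeighborSelector {

    /** Returns up to k user ids, most similar first. */
    static List<String> topK(Map<String, Double> simToTarget, int k) {
        List<Map.Entry<String, Double>> entries = new ArrayList<>(simToTarget.entrySet());
        // Descending by similarity; ties broken by user id for a stable order
        entries.sort((x, y) -> {
            int bySim = Double.compare(y.getValue(), x.getValue());
            return bySim != 0 ? bySim : x.getKey().compareTo(y.getKey());
        });
        List<String> neighbors = new ArrayList<>();
        for (int i = 0; i < Math.min(k, entries.size()); i++) {
            neighbors.add(entries.get(i).getKey());
        }
        return neighbors;
    }

    public static void main(String[] args) {
        // Made-up similarity scores for one target user
        Map<String, Double> sims = new HashMap<>();
        sims.put("u2", 0.9);
        sims.put("u3", 0.1);
        sims.put("u4", 0.5);
        sims.put("u5", -0.3);
        System.out.println(topK(sims, 2));
    }
}
```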

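4). Compute predicted ratings

The predicted rating p(a,i) of user a for item i is the weighted average of the neighbor users' ratings of that item. Different neighbors clearly influence the target user to different degrees, so each neighbor carries its own weight in the prediction. Here user similarity serves as that weight factor.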
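The similarity-weighted prediction of step 4 can be sketched in Java. The sketch assumes one common form of the weighted average, $p(a,i) = \frac{\sum_{u \in N(a)} sim(a,u)\, r_{u,i}}{\sum_{u \in N(a)} |sim(a,u)|}$, summed over the neighbors that rated item i. Other variants exist (for example with mean-centered ratings), and the class and method names here are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Similarity-weighted rating prediction for user-based CF. Illustrative sketch. */
final class UserBasedPredictor {

    /**
     * @param neighborSims neighbor userId -> sim(target, neighbor)
     * @param ratings      userId -> (itemId -> rating)
     * @return the weighted average of the neighbors' ratings of itemId, or 0 if no neighbor rated it
     */
    static double predict(Map<String, Double> neighborSims,
                          Map<String, Map<String, Double>> ratings,
                          String itemId) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (Map.Entry<String, Double> e : neighborSims.entrySet()) {
            Map<String, Double> neighborRatings = ratings.get(e.getKey());
            if (neighborRatings == null) {
                continue;
            }
            Double rating = neighborRatings.get(itemId);
            if (rating == null) {
                continue; // this neighbor never interacted with the item
            }
            numerator += e.getValue() * rating;
            denominator += Math.abs(e.getValue());
        }
        return denominator == 0.0 ? 0.0 : numerator / denominator;
    }

    public static void main(String[] args) {
        // Made-up neighbors and ratings
        Map<String, Double> sims = new HashMap<>();
        sims.put("u2", 0.8);
        sims.put("u3", 0.2);

        Map<String, Map<String, Double>> ratings = new HashMap<>();
        Map<String, Double> r2 = new HashMap<>();
        r2.put("i1", 4.0);
        Map<String, Double> r3 = new HashMap<>();
        r3.put("i1", 2.0);
        ratings.put("u2", r2);
        ratings.put("u3", r3);

        System.out.println(predict(sims, ratings, "i1"));
    }
}
```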

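The implementation steps of user-based collaborative filtering are:

1). Aggregate each user's scores for items in real time to produce the user-item table (i.e. build the user-item rating matrix);

2). Compute the similarity between every pair of users to produce the user-user score table, and sort it;

3). Sort each user's item set;

4). For the user to recommend to, select the N most similar users from the user-user score table, and from the user-item table take the top M items of each of those N users' sorted item sets;

5). The resulting N*M items are the set recommended to that user.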
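Steps 4 and 5 can be sketched as follows, assuming the neighbors are already ranked by similarity and each user's items are already ranked by score. The class name TopNRecommender and its parameters are hypothetical. Because neighbors may share items, collecting into a set yields at most N*M distinct recommendations.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Collects the top-M items of the top-N neighbors. Illustrative sketch. */
final class TopNRecommender {

    /**
     * @param rankedNeighbors   neighbor ids, most similar first
     * @param rankedItemsByUser userId -> that user's items, highest score first
     * @param n                 number of neighbors to use
     * @param m                 number of items to take from each neighbor
     */
    static Set<String> recommend(List<String> rankedNeighbors,
                                 Map<String, List<String>> rankedItemsByUser,
                                 int n, int m) {
        Set<String> result = new LinkedHashSet<>(); // keeps insertion order, drops duplicates
        int usedNeighbors = 0;
        for (String neighbor : rankedNeighbors) {
            if (usedNeighbors >= n) {
                break;
            }
            List<String> items = rankedItemsByUser.get(neighbor);
            if (items == null) {
                continue; // no item list for this neighbor; try the next one
            }
            for (int i = 0; i < Math.min(m, items.size()); i++) {
                result.add(items.get(i));
            }
            usedNeighbors++;
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up neighbors and item lists
        List<String> neighbors = new ArrayList<>(Arrays.asList("u2", "u3", "u4"));
        Map<String, List<String>> items = new HashMap<>();
        items.put("u2", Arrays.asList("a", "b", "c"));
        items.put("u3", Arrays.asList("b", "d"));
        items.put("u4", Arrays.asList("e"));
        System.out.println(recommend(neighbors, items, 2, 2));
    }
}
```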

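2.    Item-based collaborative filtering

Item-based collaborative filtering uses the user-item rating matrix to measure how similarly items are rated, and takes the n items most similar to the target item as its nearest-neighbor set. It then predicts the rating of the target item by weighting those similar neighbors, sorts the predicted ratings, and recommends the N highest-rated items to the current user. Here an "item" can be understood as a product handled by the system. The workflow is shown in Figure 2.

Figure 2. Workflow of the item-based collaborative filtering algorithm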

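The workflow of item-based collaborative filtering is:

First, read the target user's set of rating records $I_u$; then compute the similarity between item i and the other items in $I_u$ and select the k nearest neighbors; next, compute the predicted rating of every item in the candidate set from the similarity-based rating formula; finally, recommend the N items with the highest predicted ratings to the user.

In item-based collaborative filtering, the predicted rating is a weighted average of the target user's ratings of other items. The items the user rated in the past differ in how related they are to the current item i, so each of them carries a different weight in the calculation. The prediction function p(u,i) uses item similarity as the weight factor.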
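A minimal Java sketch of the item-based prediction follows. It assumes the common weighted-average form $p(u,i) = \frac{\sum_{j \in I_u} sim(i,j)\, r_{u,j}}{\sum_{j \in I_u} |sim(i,j)|}$, restricted to the items j that the user rated and for which a similarity to i is known. That form, like the class name ItemBasedPredictor, is an assumption of this example.

```java
import java.util.HashMap;
import java.util.Map;

/** Similarity-weighted rating prediction for item-based CF. Illustrative sketch. */
final class ItemBasedPredictor {

    /**
     * @param userRatings  itemId -> the target user's rating of that item
     * @param simsToTarget itemId -> sim(target item, itemId)
     * @return the weighted average of the user's ratings, or 0 if no rated item has a known similarity
     */
    static double predict(Map<String, Double> userRatings, Map<String, Double> simsToTarget) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (Map.Entry<String, Double> e : userRatings.entrySet()) {
            Double sim = simsToTarget.get(e.getKey());
            if (sim == null) {
                continue; // no similarity known between this rated item and the target item
            }
            numerator += sim * e.getValue();
            denominator += Math.abs(sim);
        }
        return denominator == 0.0 ? 0.0 : numerator / denominator;
    }

    public static void main(String[] args) {
        // Made-up ratings and similarities
        Map<String, Double> userRatings = new HashMap<>();
        userRatings.put("i1", 5.0);
        userRatings.put("i2", 3.0);

        Map<String, Double> sims = new HashMap<>();
        sims.put("i1", 0.5);
        sims.put("i2", 0.25);
        sims.put("i3", 0.9); // not rated by the user, so it is ignored

        System.out.println(predict(userRatings, sims));
    }
}
```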

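The implementation steps of item-based collaborative filtering are:

1). Aggregate each user's scores for items in real time to produce the user-item table (i.e. build the user-item rating matrix);

2). Compute the similarity between every pair of items to produce the item-item score table, and sort it;

3). Sort each user's item set;

4). For the user to recommend to, take the items that user has already chosen and, using the item-item table, select the N items most similar to them;

5). The resulting N items are the set recommended to that user.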

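3.    Comparing user-based and item-based collaborative filtering

User-based collaborative filtering:

It can help users discover new items, but it requires fairly complex online computation and has to handle the new-user problem.

Item-based collaborative filtering:

It is accurate, stable, and controllable, and lends itself to offline computation, but its recommendations are less diverse and generally do not surprise the user.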

III.    Implementation

For mobile recommendation, we choose to implement user-based collaborative filtering.

1.    Data model and entity classes

User behavior data (user.csv):

user_id,item_id,behavior_type,user_geohash,item_category,time

10001082,285259775,1,97lk14c,4076,2014-12-08 18

10001082,4368907,1,,5503,2014-12-12 12

10001082,4368907,1,,5503,2014-12-12 12

10001082,53616768,1,,9762,2014-12-02 15

10001082,151466952,1,,5232,2014-12-12 11

10001082,53616768,4,,9762,2014-12-02 15

10001082,290088061,1,,5503,2014-12-12 12

10001082,298397524,1,,10894,2014-12-12 12

10001082,32104252,1,,6513,2014-12-12 12

10001082,323339743,1,,10894,2014-12-12 12

Item information (item.csv):

item_id,item_geohash,item_category

100002303,,3368

100003592,,7995

100006838,,12630

100008089,,7791

100012750,,9614

100014072,,1032

100014463,,9023

100019387,,3064

100023812,,6700

package entity;

public class Item implements Comparable<Item> {
	private String itemId;
	private String itemGeoHash;
	private String itemCategory;
	private double weight;
	public String getItemId() {
		return itemId;
	}
	public void setItemId(String itemId) {
		this.itemId = itemId;
	}
	public String getItemGeoHash() {
		return itemGeoHash;
	}
	public void setItemGeoHash(String itemGeoHash) {
		this.itemGeoHash = itemGeoHash;
	}
	public String getItemCategory() {
		return itemCategory;
	}
	public void setItemCategory(String itemCategory) {
		this.itemCategory = itemCategory;
	}
	public double getWeight() {
		return weight;
	}
	public void setWeight(double weight) {
		this.weight = weight;
	}
	@Override
	public String toString() {
		return "Item [itemId=" + itemId + ", itemGeoHash=" + itemGeoHash
				+ ", itemCategory=" + itemCategory + ", weight=" + weight + "]";
	}
	@Override
	public int compareTo(Item o) {
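		// Descending by weight; Double.compare keeps fractional differences that an int cast would truncate to 0
		return Double.compare(o.weight, this.weight);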
	}
	
	
}
package entity;

public class Score implements Comparable<Score> {
	private String userId;      // user identifier
	private String itemId;      // item identifier
	private double score;
	public String getUserId() {
		return userId;
	}
	public void setUserId(String userId) {
		this.userId = userId;
	}
	public String getItemId() {
		return itemId;
	}
	public void setItemId(String itemId) {
		this.itemId = itemId;
	}
	public double getScore() {
		return score;
	}
	public void setScore(double score) {
		this.score = score;
	}
	@Override
	public String toString() {
		return "Score [userId=" + userId + ", itemId=" + itemId + ", score="
				+ score + "]";
	}
	@Override
	public int compareTo(Score o) {
		if ((this.score - o.score) < 0) {
			return 1;
		}else if ((this.score - o.score) > 0) {
			return -1;
		}else {
			return 0;
		}
	}
	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + ((itemId == null) ? 0 : itemId.hashCode());
		long temp;
		temp = Double.doubleToLongBits(score);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		result = prime * result + ((userId == null) ? 0 : userId.hashCode());
		return result;
	}
	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		Score other = (Score) obj;
		if (itemId == null) {
			if (other.itemId != null)
				return false;
		} else if (!itemId.equals(other.itemId))
			return false;
		if (Double.doubleToLongBits(score) != Double
				.doubleToLongBits(other.score))
			return false;
		if (userId == null) {
			if (other.userId != null)
				return false;
		} else if (!userId.equals(other.userId))
			return false;
		return true;
	}
	
}
package entity;

public class User implements Comparable<User> {
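	private String userId;      // user identifier
	private String itemId;      // item identifier
	private double behaviorType;   // type of the user's action on the item (may be empty): view, favorite, add to cart, purchase, encoded as 1, 2, 3, 4
	private String userGeoHash; // spatial identifier of the user's location
	private String itemCategory;// item category identifier
	private String time;        // time of the action
	private int count;
	private double weight;      // weight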
	public String getUserId() {
		return userId;
	}
	public void setUserId(String userId) {
		this.userId = userId;
	}
	public String getItemId() {
		return itemId;
	}
	public void setItemId(String itemId) {
		this.itemId = itemId;
	}
	public double getBehaviorType() {
		return behaviorType;
	}
	public void setBehaviorType(double behaviorType) {
		this.behaviorType = behaviorType;
	}
	
	public String getUserGeoHash() {
		return userGeoHash;
	}
	public void setUserGeoHash(String userGeoHash) {
		this.userGeoHash = userGeoHash;
	}
	public String getItemCategory() {
		return itemCategory;
	}
	public void setItemCategory(String itemCategory) {
		this.itemCategory = itemCategory;
	}
	public String getTime() {
		return time;
	}
	public void setTime(String time) {
		this.time = time;
	}
	
	
	
	@Override
	public String toString() {
		return "User [userId=" + userId + ", itemId=" + itemId
				+ ", behaviorType=" + behaviorType + ", count=" + count + "]";
	}
	
	public int getCount() {
		return count;
	}
	public void setCount(int count) {
		this.count = count;
	}
	public double getWeight() {
		return weight;
	}
	public void setWeight(double weight) {
		this.weight = weight;
	}
	@Override
	public int compareTo(User o) {
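		// Descending by weight; Double.compare keeps fractional differences that an int cast would truncate to 0
		return Double.compare(o.weight, this.weight);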
	}
	
}
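Item and User are ordered by a double weight, highest first. A small sketch of why such a comparison should go through Double.compare rather than an int cast of the difference: any two weights that differ by less than 1 collapse to "equal" under the cast. The class WeightOrdering and its two helper methods are hypothetical and exist only for this comparison.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Contrasts two ways of ordering double weights in descending order. Illustrative sketch. */
final class WeightOrdering {

    /** Cast-based comparison: the fractional part of the difference is truncated toward zero. */
    static int byCast(double a, double b) {
        return (int) (-1 * (a - b));
    }

    /** Descending comparison that handles fractional differences correctly. */
    static int byCompare(double a, double b) {
        return Double.compare(b, a);
    }

    public static void main(String[] args) {
        System.out.println("byCast(0.9, 0.2)    = " + byCast(0.9, 0.2));
        System.out.println("byCompare(0.9, 0.2) = " + byCompare(0.9, 0.2));

        List<Double> weights = new ArrayList<>(Arrays.asList(0.2, 0.9, 0.5));
        weights.sort(WeightOrdering::byCompare); // largest weight first
        System.out.println(weights);
    }
}
```

This matters most for similarity-style weights in [-1, 1], where every difference is fractional and a cast-based comparator would leave a list effectively unsorted.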

2.    Utility classes

File handling utility:

package util;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import entity.Item;
import entity.Score;
import entity.User;


public class FileTool {
	
	public static FileReader fr=null;
	public static BufferedReader br=null;
	public static String line=null;
	
	public static FileOutputStream fos1 = null,fos2 = null,fos3 = null;
	public static PrintStream ps1 = null,ps2 = null,ps3 = null;
	
	public static int count = 0;
	
	/** 
	 * Initialize the file writer (single stream)
	 * */
	public static void initWriter1(String writePath) {
		try {
			fos1 = new FileOutputStream(writePath);
			ps1 = new PrintStream(fos1);
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}
	/** 
	 * Close the file reader
	 * */
	public static void closeReader() {
		try {
			br.close();
			fr.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
	/** 
	 * Close the file writer (single stream)
	 * */
	public static void closeWriter1() {
		try {
			ps1.close();
			fos1.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
	/** 
	 * Initialize the file writers (two streams)
	 * */
	public static void initWriter2(String writePath1,String writePath2) {
		try {
			fos1 = new FileOutputStream(writePath1);
			ps1 = new PrintStream(fos1);
			fos2 = new FileOutputStream(writePath2);
			ps2 = new PrintStream(fos2);
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}
	/** 
	 * Close the file writers (two streams)
	 * */
	public static void closeWriter2() {
		try {
			ps1.close();
			fos1.close();
			ps2.close();
			fos2.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
	/** 
	 * Initialize the file writers (three streams)
	 * */
	public static void initWriter3(String writePath1,String writePath2,String writePath3) {
		try {
			fos1 = new FileOutputStream(writePath1);
			ps1 = new PrintStream(fos1);
			fos2 = new FileOutputStream(writePath2);
			ps2 = new PrintStream(fos2);
			fos3 = new FileOutputStream(writePath3);
			ps3 = new PrintStream(fos3);
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}
	/** 
	 * Close the file writers (three streams)
	 * */
	public static void closeWriter3() {
		try {
			ps1.close();
			fos1.close();
			ps2.close();
			fos2.close();
			ps3.close();
			fos3.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
	
	public static List readFileOne(String path, boolean isTitle, String token, String pattern, String process) throws Exception {
		List<Object> ret = new ArrayList<Object>();
		
		fr = new FileReader(path);
		br = new BufferedReader(fr);
		int count = 0,i = 0;
		
		if (isTitle) {
			line = br.readLine();
			count++;
		}
		
		while((line = br.readLine()) != null){
			String[] strArr = line.split(token);
			switch (pattern) {
			case "item":
				ret.add(ParseTool.parseItem(strArr));
				break;
			case "user":
				ret.add(ParseTool.parseUser(strArr, process));
				break;
			default:
				ret.add(line);
				break;
			}
			count++;
			if (count/100000 == 1) {
				i++;
				System.out.println(100000*i);
				count = 0;
			}
		}
		
		closeReader();
		
		return ret;
	}
	
	public static void makeSampleData(String inputPath,boolean isTitle,String outputPath,int threshold) throws Exception {
		
		fr = new FileReader(inputPath);
		br = new BufferedReader(fr);
		initWriter1(outputPath);
		
		if (isTitle) {
			line = br.readLine();
		}
		int count = 0;
		while((line = br.readLine()) != null){
			ps1.println(line);
			count++;
			if (count == threshold) {
				break;
			}
		}
		closeReader();
	}
	public static List<String> traverseFolder(String dir) {
        File file = new File(dir);
        // File.list() returns null when the path does not exist or is not a directory
        String[] fileList = file.exists() ? file.list() : null;
        List<String> list = new ArrayList<String>();
        if (fileList != null) {
        	for(String path : fileList){
        		list.add(path);
        	}
        }
        return list;
    }
	public static Map<String, List<Score>> loadScoreMap(String path, boolean isTitle, String token, String type) throws Exception {
		fr = new FileReader(path);
		br = new BufferedReader(fr);
		
		if (isTitle) {
			line = br.readLine();
		}
		
		Map<String, List<Score>> scoreMap = new HashMap<String, List<Score>>();
		
		while((line = br.readLine()) != null){
			String[] arr = line.split(token);
			Score score = null;
			switch (type) {
			case "userCF":
				score = ParseTool.parseScoreByItemCF(arr);
				break;
			case "itemCF":
				score = ParseTool.parseScoreByUserCF(arr);
				break;
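			default:
				// Fail fast instead of hitting a NullPointerException on score below
				throw new IllegalArgumentException("Unknown score type: " + type);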
			}
			List<Score> temp = new ArrayList<Score>();
			if (scoreMap.containsKey(score.getItemId())) {
				temp = scoreMap.get(score.getItemId());
			}
			temp.add(score);
			scoreMap.put(score.getItemId(), temp);
		}
		closeReader();
		return scoreMap;
	}
	
	public static Map<String, List<String>> loadPredictData(String path, boolean isTitle, String token, int n) throws Exception {
		fr = new FileReader(path);
		br = new BufferedReader(fr);
		
		if (isTitle) {
			line = br.readLine();
		}
		Map<String, List<String>> map = new HashMap<String, List<String>>();
		while((line = br.readLine()) != null){
			String[] arr = line.split(token);
			String userId = arr[0];
			String itemId = arr[1];
			List<String> temp = new ArrayList<String>();
			if (map.containsKey(userId)) {
				temp = map.get(userId);
			}
			// Keep at most n distinct items per user
			if (!temp.contains(itemId) && temp.size() < n) {
				temp.add(itemId);
				map.put(userId, temp);
				count++;
			}
		}
		
		closeReader();
		return map;
	}
	
	public static Map<String, List<String>> loadTestData(Map<String, List<String>> predictMap, String dir, boolean isTitle, String token) throws Exception {
		
		List<String> fileList = traverseFolder(dir);
		Set<String> predictKeySet = predictMap.keySet();
		Map<String, List<String>> testMap = new HashMap<String, List<String>>();
		for(String predictKey : predictKeySet){
			if (fileList.contains(predictKey)) {
				List<String> itemList = getIdList(dir + predictKey, isTitle, token);
				testMap.put(predictKey, itemList);
			}
		}
		return testMap;
	}
	
	public static List<String> getIdList(String path, boolean isTitle, String token) throws Exception {
		fr = new FileReader(path);
		br = new BufferedReader(fr);
		
		if (isTitle) {
			line = br.readLine();
		}
		
		List<String> list = new ArrayList<String>();
		while((line = br.readLine()) != null){
			String[] arr = line.split(token);
			if (!list.contains(arr[0].trim())) {
				list.add(arr[0].trim());
			}
			count++;
		}
		
		closeReader();
		return list;
	}
	
	public static Map<String, Double> loadUser_ItemData(String path,boolean isTitle,String token) throws Exception {
		fr = new FileReader(path);
		br = new BufferedReader(fr);
		
		if (isTitle) {
			line = br.readLine();
		}
		Map<String, Double> map = new HashMap<String, Double>();
		while((line = br.readLine()) != null){
			String[] arr = line.split(token);
			String itemId = arr[1];
			double score = Double.valueOf(arr[2]);
			if(map.containsKey(itemId)){
				double temp = map.get(itemId);
				if (temp > score) {
					score = temp;
				}
			}
			map.put(itemId, score);
		}
		closeReader();
		return map;
	}
	
	public static void makeSampleData(String path, boolean isTitle,String token, List<String> userList, List<String> itemList) throws Exception {
		fr = new FileReader(path);
		br = new BufferedReader(fr);
		
		if (isTitle) {
			line = br.readLine();
		}
		
		while((line = br.readLine()) != null){
			String[] arr = line.split(token);
			String userId = arr[0];
			String itemId = arr[1];
			if (userList.contains(userId) && itemList.contains(itemId)) {
				FileTool.ps1.println(line);
			}
		}
		
		closeReader();
		
	}
	
}

Parsing utility:

package util;

import entity.Item;
import entity.Score;
import entity.User;

public class ParseTool {
	
	public static boolean isNumber(String str) {
		int i,n;
		n = str.length();
		for(i = 0;i < n;i++){
			if (!Character.isDigit(str.charAt(i))) {
				return false;
			}
		}
		return true;
	}
	
	public static Item parseItem(String[] contents) {
		Item item = new Item();
		if (contents[0] != null && !contents[0].isEmpty()) {
			item.setItemId(contents[0].trim());
		}
		if (contents[1] != null && !contents[1].isEmpty()) {
			item.setItemGeoHash(contents[1].trim());
		}
		if (contents[2] != null && !contents[2].isEmpty()) {
			item.setItemCategory(contents[2].trim());
		}
		return item;
	}
	public static User parseUser(String[] contents, String type) {
		User user = new User();
		int n = contents.length;
		if (contents[0] != null && !contents[0].isEmpty()) {
			user.setUserId(contents[0].trim());
		}
		if (contents[1] != null && !contents[1].isEmpty()) {
			user.setItemId(contents[1].trim());
		}
		if ("mapUser".equals(type)) {
			// 1. Enable this block when running SpliteFileAndMakeScoreTable; comment out the others
			if (contents[2] != null && !contents[2].isEmpty()) {
				user.setBehaviorType(Double.valueOf(contents[2].trim()));
			}
			/*
			// sample2
			if (contents[3] != null && !contents[3].isEmpty()) {
				user.setUserGeoHash(contents[3].trim());
			}
			if (contents[4] != null && !contents[4].isEmpty()) {
				user.setItemCategory(contents[4].trim());
			}
			if (contents[5] != null && !contents[5].isEmpty()) {
				user.setTime(contents[5].trim());
			}
			*/
			
			// movielens
			if (contents[3] != null && !contents[3].isEmpty()) {
				user.setTime(contents[3].trim());
			}
		}
		if ("reduceUser".equals(type)) {
			// 2. Enable this block when running ReducefileTest; comment out the others
			if (contents[2] != null && !contents[2].isEmpty()) {
				user.setBehaviorType(Double.parseDouble(contents[2].trim()));
			}
			if (contents[n-1] != null && !contents[n-1].isEmpty()) {
				user.setCount(Integer.valueOf(contents[n-1].trim()));
			}
		}
		if ("predict".equals(type)) {
			// 3. Enable this block when running PredictTest; comment out the others
			if (contents[n-1] != null && !contents[n-1].isEmpty()) {
				user.setWeight(Double.valueOf(contents[n-1].trim()));
			}
		}
		return user;
	}
	public static Score parseScoreByUserCF(String[] contents) {
		Score score = new Score();
		if (contents[0] != null && !contents[0].isEmpty()) {
			score.setUserId(contents[0].trim());
		}
		if (contents[1] != null && !contents[1].isEmpty()) {
			score.setItemId(contents[1].trim());
		}
		if (contents[2] != null && !contents[2].isEmpty() && !"Infinity".equals(contents[2])) {
			double score_temp = Double.parseDouble(contents[2].trim());
			if (score_temp < 0.0001) {
				score_temp = 0;
			}
			score.setScore(score_temp);
		}
		return score;
	}
	public static Score parseScoreByItemCF(String[] contents) {
		Score score = new Score();
		if (contents[0] != null && !contents[0].isEmpty()) {
			score.setItemId(contents[0].trim());
		}
		if (contents[1] != null && !contents[1].isEmpty()) {
			score.setUserId(contents[1].trim());
		}
		if (contents[2] != null && !contents[2].isEmpty() && !"Infinity".equals(contents[2])) {
			double score_temp = Double.parseDouble(contents[2].trim());
			if (score_temp < 0.0001) {
				score_temp = 0;
			}
			score.setScore(score_temp);
		}
		return score;
	}
}

3.    Data processing modules

Data processing module for user-based collaborative filtering:

package service;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;

import util.FileTool;
import entity.Item;
import entity.Score;
import entity.User;

public class DataProcessByUserCF {
	
	public static final double[] w = {0,10,20,30}; 
	
	public static void output(Map<String, Map<String, List<User>>> userMap,String outputPath) {
		for(Entry<String, Map<String, List<User>>> entry : userMap.entrySet()){
			FileTool.initWriter1(outputPath + entry.getKey());
			Map<String, List<User>> temp = entry.getValue();
			for(Entry<String, List<User>> tempEntry : temp.entrySet()){
				List<User> users = tempEntry.getValue();
				int count = users.size();
				for(User user : users){
					FileTool.ps1.print(user.getUserId() + "\t");
					FileTool.ps1.print(user.getItemId() + "\t");
					FileTool.ps1.print(user.getBehaviorType() + "\t");
					//FileTool.ps1.print(user.getUserGeoHash() + "\t");
					//FileTool.ps1.print(user.getItemCategory() + "\t");
					//FileTool.ps1.print(user.getTime() + "\t");
					FileTool.ps1.print(count + "\n");
				}
			}
			// Close this user's file before the next iteration opens a new one
			FileTool.closeWriter1();
		}
	}
	
	public static void output(Map<String, Map<String, Double>> scoreTable, String outputPath, Set<String> userSet, Set<String> itemSet, String token) {
		FileTool.initWriter1(outputPath);
		
		for(String itemId: itemSet){
			FileTool.ps1.print(token + itemId);
		}
		FileTool.ps1.println();
		for(String userId : userSet){
			FileTool.ps1.print(userId + token);
			Map<String, Double> itemMap = scoreTable.get(userId);
			for(String itemId: itemSet){
				if(itemMap.containsKey(itemId)){
					FileTool.ps1.print(itemMap.get(itemId));
				}else {
					//FileTool.ps1.print(0);
				}
				FileTool.ps1.print(token);
			}
			FileTool.ps1.print("\n");
		}
		// Release the output file once the whole table has been written
		FileTool.closeWriter1();
	}
	
	public static void outputUser(List<User> userList) {
		for(User user : userList){
			FileTool.ps1.println(user.getUserId() + "\t" + user.getItemId() + "\t" + user.getWeight());
		}
	}
	
	public static void outputScore(List<Score> scoreList) {
		for(Score score : scoreList){
			FileTool.ps1.println(score.getUserId() + "\t" + score.getItemId() + "\t" + score.getScore());
		}
	}
	
	public static void outputRecommendList(Map<String, Set<String>> map) {
		for(Entry<String, Set<String>> entry : map.entrySet()){
			String userId = entry.getKey();
			Set<String> itemSet = entry.getValue();
			for(String itemId : itemSet){
				FileTool.ps1.println(userId + "," + itemId);
			}
		}
	}
	
	public static void output(Map<String, Set<String>> map) {
		for(Entry<String, Set<String>> entry : map.entrySet()){
			String userId = entry.getKey();
			Set<String> set = entry.getValue();
			for(String itemId : set){
				FileTool.ps1.println(userId + "\t" + itemId);
			}
		}
	}
	
	public static Map<String, Map<String, List<User>>> mapByUser(List<User> userList,Set<String> userSet,Set<String> itemSet) {
		Map<String, Map<String, List<User>>> userMap = new HashMap<>();
		for(User user: userList){
			Map<String, List<User>> tempMap = new HashMap<String, List<User>>();
			List<User> tempList = new ArrayList<User>();
			if (!userMap.containsKey(user.getUserId())) {
			}else {
				tempMap = userMap.get(user.getUserId());
				if (!tempMap.containsKey(user.getItemId())) {
				}else {
					tempList = tempMap.get(user.getItemId());
				}
			}
			tempList.add(user);
			tempMap.put(user.getItemId(), tempList);
			userMap.put(user.getUserId(), tempMap);
			userSet.add(user.getUserId());
			itemSet.add(user.getItemId());
			
		}
		return userMap;
	}
	
	public static Map<String, Map<String, Double>> makeScoreTable(Map<String, Map<String, List<User>>> userMap) {
		Map<String, Map<String, Double>> scoreTable = new HashMap<String, Map<String,Double>>();
		for(Entry<String, Map<String, List<User>>> userEntry : userMap.entrySet()){
			
			Map<String, List<User>> itemMap = userEntry.getValue();
			String userId = userEntry.getKey();
			
			Map<String, Double> itemScoreMap = new HashMap<String, Double>();
			
			for(Entry<String, List<User>> itemEntry : itemMap.entrySet()){
				String itemId = itemEntry.getKey();
				List<User> users = itemEntry.getValue();
				double weight = 0.0;
				
				double maxType = 0;
				for(User user : users){
					if (user.getBehaviorType() > maxType) {
						maxType = user.getBehaviorType();
					}
				}
				
				int count = users.size();
				if (maxType != 0) {
					//weight += w[(int)maxType-1]; //sample2
					weight += (maxType-1) * 10; //movielens
				}
				weight += count;
				
				itemScoreMap.put(itemId, weight);
			}
			scoreTable.put(userId, itemScoreMap);
		}
		return scoreTable;
	}
	public static double calculateWeight(double behaviorType, int count) {
		//double weight = w[(int)(behaviorType)-1] + count; //sample2
		double weight = (behaviorType - 1) * 10 + count; //movielens
		return weight;
	}
	public static List<User> reduceUserByItem(List<User> userList) {
		List<User> list = new ArrayList<User>();
		Map<String, User> userMap = new LinkedHashMap<String, User>();
		for(User user : userList){
			List<String> itemList = new ArrayList<String>();
			String itemId = user.getItemId();
			if (!userMap.containsKey(itemId)) {
				double weight = calculateWeight(user.getBehaviorType(), user.getCount());
				user.setWeight(weight);
				userMap.put(itemId, user);
				list.add(user);
			}else {
				User temp = userMap.get(itemId);
				double weight = calculateWeight(user.getBehaviorType(), user.getCount());
				if (temp.getWeight() < weight) {
					user.setWeight(weight);
					userMap.put(itemId, user);
					list.remove(temp);
					list.add(user);
				}
			}
		}
		userMap.clear();
		return list;
	}
	
	public static void sortScoreMap(Map<String, List<Score>> scoreMap) {
		Set<String> userSet = scoreMap.keySet();
		for(String userId : userSet){
			List<Score> temp = scoreMap.get(userId);
			Collections.sort(temp);
			scoreMap.put(userId, temp);
		}
	}
	public static Map<String, Set<String>> predict(Map<String, List<Score>> scoreMap, List<String> fileNameList, String userDir, int topNUser, double score_threshold, int topNItem, double weight_threshold) throws Exception {
		Map<String, Set<String>> recommendList = new HashMap<String, Set<String>>();
		for(Entry<String, List<Score>> entry : scoreMap.entrySet()){
			String userId1 = entry.getKey();
			List<Score> list = entry.getValue();
			int countUser = 0;
			Set<String> predictItemSet = new LinkedHashSet<String>();
			for(Score score : list){
				if (score.getScore() <= score_threshold) {
					break;
				}
				String userId2 = score.getUserId();
				if(fileNameList.contains(userId2)){
					List<User> userList = FileTool.readFileOne(userDir + userId2, false, "\t", "user", "predict");
					int countItem = 0;
					for(User user : userList){
						if (user.getWeight() <= weight_threshold) {
							continue;
						}
						predictItemSet.add(user.getItemId());
						countItem++;
						if (countItem == topNItem) {
							break;
						}
						
					}
					countUser++;
				}
				if (countUser == topNUser) {
					break;
				}
			}
			recommendList.put(userId1, predictItemSet);
		}
		return recommendList;
	}
	public static void prediction(Map<String, List<String>> predictMap,int predictN, Map<String, List<String>> referenceMap, int refN) {
		
		for(int i = 1;i <= 10;i++){
			int count = 0;
			predictN = 0;
			for(Entry<String, List<String>> predictEntity : predictMap.entrySet()){
				String userId = predictEntity.getKey();
				if (referenceMap.containsKey(userId)) {
					List<String> predictList = predictEntity.getValue();
					int j = 0;
					for(String itemId : predictList){
						predictN++;
						if (referenceMap.get(userId).contains(itemId)) {
							count++;
						}
						j++;
						if (j == i) {
							break;
						}
					}
				}
			}
			
			// Guard against zero denominators, which would otherwise produce NaN
			double precision = predictN == 0 ? 0 : (1.0 * count / predictN) * 100;
			double recall = refN == 0 ? 0 : (1.0 * count / refN) * 100;
			if (recall > 100) {
				recall = 100;
			}
			double f1 = (precision + recall) == 0 ? 0 : (2 * precision * recall)/(precision + recall);
			//System.out.println(predictN);
			//System.out.println(refN);
			System.out.println("Items recommended per user: " + i + ", hits: " + count + ", total recommended: " + predictN + ", actual purchases: " + refN + ", precision: " + count + "/" + predictN + "=" + precision + "%, recall: " + count + "/" + refN + "=" + recall + "%, f1: " + f1 + "%");
			
		}
		
		
	}
	
}
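The scoring rule used by makeScoreTable and calculateWeight, (maxType - 1) * 10 + count, can be isolated in a few lines. The sketch below is a standalone restatement of that rule. The class name BehaviorScore and the varargs signature are assumptions, and the two sample calls use rows from the user.csv excerpt shown earlier.

```java
/** Standalone restatement of the (maxType - 1) * 10 + count scoring rule. Illustrative sketch. */
final class BehaviorScore {

    /**
     * @param behaviorTypes every behavior_type one user recorded for one item (1 = view ... 4 = purchase)
     * @return the user-item score: a bonus for the strongest behavior plus the number of interactions
     */
    static double score(int... behaviorTypes) {
        int maxType = 0;
        for (int type : behaviorTypes) {
            maxType = Math.max(maxType, type);
        }
        double weight = 0.0;
        if (maxType != 0) {
            weight += (maxType - 1) * 10; // view adds 0, favorite 10, cart 20, purchase 30
        }
        weight += behaviorTypes.length;   // each interaction adds 1
        return weight;
    }

    public static void main(String[] args) {
        // User 10001082, item 53616768: one view (1) and one purchase (4)
        System.out.println(score(1, 4));
        // User 10001082, item 4368907: two views
        System.out.println(score(1, 1));
    }
}
```

Under this rule a single purchase outweighs any realistic number of views, which matches the competition goal of predicting purchases.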

Data processing module for item-based collaborative filtering:

package service;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;

import util.FileTool;
import entity.Item;
import entity.Score;
import entity.User;

public class DataProcessByItemCF {
	
	public static final double[] w = {0,10,20,30}; 
	
	public static void output(Map<String, Map<String, List<User>>> userMap,String outputPath) {
		for(Entry<String, Map<String, List<User>>> entry : userMap.entrySet()){
			FileTool.initWriter1(outputPath + entry.getKey());
			Map<String, List<User>> temp = entry.getValue();
			for(Entry<String, List<User>> tempEntry : temp.entrySet()){
				List<User> users = tempEntry.getValue();
				int count = users.size();
				for(User user : users){
					FileTool.ps1.print(user.getUserId() + "\t");
					FileTool.ps1.print(user.getItemId() + "\t");
					FileTool.ps1.print(user.getBehaviorType() + "\t");
					//FileTool.ps1.print(user.getUserGeoHash() + "\t");
					//FileTool.ps1.print(user.getItemCategory() + "\t");
					//FileTool.ps1.print(user.getTime() + "\t");
					FileTool.ps1.print(count + "\n");
				}
			}
			FileTool.closeWriter1();
		}
	}
	
	public static void output(Map<String, Map<String, Double>> scoreTable, String outputPath, Set<String> userSet, Set<String> itemSet, String token) {
		FileTool.initWriter1(outputPath);
		
		for(String userId: userSet){
			FileTool.ps1.print(token + userId);
		}
		FileTool.ps1.println();
		for(String itemId : itemSet){
			FileTool.ps1.print(itemId + token);
			Map<String, Double> userMap = scoreTable.get(itemId);
			for(String userId: userSet){
				if(userMap.containsKey(userId)){
					FileTool.ps1.print(userMap.get(userId));
				}else {
					//FileTool.ps1.print(0);
				}
				FileTool.ps1.print(token);
			}
			FileTool.ps1.print("\n");
		}
		
		FileTool.closeWriter1();
	}
	
	public static void outputUser(List<User> userList) {
		for(User user : userList){
			FileTool.ps1.println(user.getUserId() + "\t" + user.getItemId() + "\t" + user.getWeight());
		}
	}
	
	public static void outputScore(List<Score> scoreList) {
		for(Score score : scoreList){
			FileTool.ps1.println(score.getUserId() + "\t" + score.getItemId() + "\t" + score.getScore());
		}
	}
	
	public static void outputRecommendList(Map<String, Set<String>> map) {
		for(Entry<String, Set<String>> entry : map.entrySet()){
			String userId = entry.getKey();
			Set<String> itemSet = entry.getValue();
			for(String itemId : itemSet){
				FileTool.ps1.println(userId + "," + itemId);
			}
		}
	}
	
	public static void output(Map<String, Set<String>> map) {
		for(Entry<String, Set<String>> entry : map.entrySet()){
			String userId = entry.getKey();
			Set<String> set = entry.getValue();
			for(String itemId : set){
				FileTool.ps1.println(userId + "\t" + itemId);
			}
		}
	}
	
	public static Map<String, Map<String, List<User>>> mapByUser(List<User> userList, Set<String> userSet, Set<String> itemSet) {
		// userId -> itemId -> all behavior records of that user on that item
		Map<String, Map<String, List<User>>> userMap = new HashMap<>();
		for(User user: userList){
			Map<String, List<User>> tempMap = userMap.get(user.getUserId());
			if (tempMap == null) {
				tempMap = new HashMap<String, List<User>>();
				userMap.put(user.getUserId(), tempMap);
			}
			List<User> tempList = tempMap.get(user.getItemId());
			if (tempList == null) {
				tempList = new ArrayList<User>();
				tempMap.put(user.getItemId(), tempList);
			}
			tempList.add(user);
			// collect the distinct user and item ids as a side effect
			userSet.add(user.getUserId());
			itemSet.add(user.getItemId());
		}
		return userMap;
	}
	
	public static Map<String, Map<String, Double>> makeScoreTable(Map<String, Map<String, List<User>>> userMap) {
		Map<String, Map<String, Double>> scoreTable = new LinkedHashMap<>();
		for(Entry<String, Map<String, List<User>>> userEntry : userMap.entrySet()){
			
			Map<String, List<User>> itemMap = userEntry.getValue();
			String userId = userEntry.getKey();
			
			Map<String, Double> userScoreMap;
			
			for(Entry<String, List<User>> itemEntry : itemMap.entrySet()){
				String itemId = itemEntry.getKey();
				List<User> users = itemEntry.getValue();
				double weight = 0.0;
				
				double maxType = 0;
				for(User user : users){
					if (user.getBehaviorType() > maxType) {
						maxType = user.getBehaviorType();
					}
				}
				
				int count = users.size();
				if (maxType != 0) {
					//weight += w[(int)maxType-1]; //sample2
					weight += (maxType-1) * 10; //movielens
				}
				weight += count;
				
				if (scoreTable.containsKey(itemId)) {
					userScoreMap = scoreTable.get(itemId);
				}else {
					userScoreMap = new LinkedHashMap<String, Double>();
				}
				
				userScoreMap.put(userId, weight);
				
				scoreTable.put(itemId, userScoreMap);
				
			}
			
		}
		return scoreTable;
	}
	public static double calculateWeight(double behaviorType, int count) {
		//double weight = w[(int)(behaviorType)-1] + count; //sample2
		double weight = (behaviorType - 1) * 10 + count; //movielens
		return weight;
	}
	public static List<User> reduceUserByItem(List<User> userList) {
		List<User> list = new ArrayList<User>();
		Map<String, User> userMap = new LinkedHashMap<String, User>();
		for(User user : userList){
			List<String> itemList = new ArrayList<String>();
			String itemId = user.getItemId();
			if (!userMap.containsKey(itemId)) {
				double weight = calculateWeight(user.getBehaviorType(), user.getCount());
				user.setWeight(weight);
				userMap.put(itemId, user);
				list.add(user);
			}else {
				User temp = userMap.get(itemId);
				double weight = calculateWeight(user.getBehaviorType(), user.getCount());
				if (temp.getWeight() < weight) {
					user.setWeight(weight);
					userMap.put(itemId, user);
					list.remove(temp);
					list.add(user);
				}
			}
		}
		userMap.clear();
		return list;
	}
	
	public static void sortScoreMap(Map<String, List<Score>> scoreMap) {
		Set<String> userSet = scoreMap.keySet();
		for(String userId : userSet){
			List<Score> temp = scoreMap.get(userId);
			Collections.sort(temp);
			scoreMap.put(userId, temp);
		}
	}
	public static Map<String, Set<String>> predict(Map<String, List<Score>> scoreMap, List<String> fileNameList, String userDir, int topNUser, double score_threshold, int topNItem, double weight_threshold) throws Exception {
		Map<String, Set<String>> recommendList = new HashMap<String, Set<String>>();
		
		for(String userFileName : fileNameList){
			//System.out.println("userFileName:"+userFileName);
			List<User> userList = FileTool.readFileOne(userDir + userFileName, false, "\t", "user", "predict");
			//System.out.println("userList:"+userList);
			for(User user : userList){
				String userId = user.getUserId();
				if (user.getWeight() <= weight_threshold) {
					continue;
				}
				String itemId = user.getItemId();
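				List<Score> itemSimilarList = scoreMap.get(itemId);
				if (itemSimilarList == null) {
					// no similarity row was loaded for this item, so nothing can be recommended from it
					continue;
				}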
				//System.out.println("itemSimilarList:"+itemSimilarList);
				Set<String> predictItemSet;
				
				if (recommendList.containsKey(userId)) {
					predictItemSet = recommendList.get(userId);
				}else {
					predictItemSet = new LinkedHashSet<String>();
				}
				
				int countItem = 0;
				for(Score similarItem : itemSimilarList){
					predictItemSet.add(similarItem.getUserId());
					if (countItem == topNItem) {
						break;
					}
					countItem++;
				}
				recommendList.put(userId, predictItemSet);
			}
		}

		return recommendList;
	}
	public static void prediction(Map<String, List<String>> predictMap,int predictN, Map<String, List<String>> referenceMap, int refN) {
		int count = 0;
		for(Entry<String, List<String>> predictEntity : predictMap.entrySet()){
			String userId = predictEntity.getKey();
			if (referenceMap.containsKey(userId)) {
				List<String> predictList = predictEntity.getValue();
				for(String itemId : predictList){
					if (referenceMap.get(userId).contains(itemId)) {
						count++;
					}
				}
			}
		}
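		// precision and recall are reported as percentages; guard against empty sets
		double precision = predictN == 0 ? 0.0 : (1.0 * count / predictN) * 100;
		double recall = refN == 0 ? 0.0 : (1.0 * count / refN) * 100;
		// with no hits at all, precision + recall is 0 and F1 would be NaN
		double f1 = (precision + recall) == 0 ? 0.0 : (2 * precision * recall)/(precision + recall);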
		System.out.println("precision="+precision+",recall="+recall+",f1="+f1);
	}
	
}
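
The scoring rule used by makeScoreTable and calculateWeight above is (strongest behavior type - 1) * 10 + number of records. The following is a minimal sketch of that rule on its own, assuming the behavior types 1 to 4 (browse, favorite, add to cart, purchase) from the data description; the class WeightSketch and its method are hypothetical names and not part of the project source.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration only; not part of the project source.
final class WeightSketch {

    // Weight of one user-item pair: the strongest behavior dominates,
    // and the number of recorded interactions is added on top.
    static double weight(List<Integer> behaviorTypes) {
        int maxType = 0;
        for (int type : behaviorTypes) {
            maxType = Math.max(maxType, type);
        }
        double weight = behaviorTypes.size();
        if (maxType != 0) {
            weight += (maxType - 1) * 10;
        }
        return weight;
    }

    public static void main(String[] args) {
        // browse, browse, add to cart: (3 - 1) * 10 + 3 = 23
        System.out.println(weight(Arrays.asList(1, 1, 3))); // prints 23.0
        // a single purchase: (4 - 1) * 10 + 1 = 31
        System.out.println(weight(Arrays.asList(4)));       // prints 31.0
    }
}
```

With this rule a single purchase outranks any pair whose strongest behavior is an add-to-cart with fewer than nine records, which is one way to read the factor 10.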

4.    Computation module

Similarity computation module for user-based collaborative filtering

package service;

import java.text.DecimalFormat;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import util.FileTool;

public class CalculateSimilarityByUserCF {
	
	public static DecimalFormat df = new DecimalFormat("#.0000");
	
	// This measure is problematic: for (a) 1,0,1 vs 1,20,1 and (b) 0,1,0 vs 0,0,0, pair (b) comes out closer than pair (a)
	public static double EuclidDist(Map<String, Double> userMap1,
			Map<String, Double> userMap2, Set<String> userSet,
			Set<String> itemSet) {
		double sum = 0;
		for (String itemId : itemSet) {
			double score1 = 0.0;
			double score2 = 0.0;
			if (userMap1.containsKey(itemId) && userMap2.containsKey(itemId)) {
				score1 = userMap1.get(itemId);
				score2 = userMap2.get(itemId);
			} else if (userMap1.containsKey(itemId)) {
				score1 = userMap1.get(itemId);
			} else if (userMap2.containsKey(itemId)) {
				score2 = userMap2.get(itemId);
			}
			double temp = Math.pow((score1 - score2), 2);
			sum += temp;
		}
		sum = Math.sqrt(sum);
		return sum;
	}

	public static double CosineDist(Map<String, Double> userMap1,
			Map<String, Double> userMap2, Set<String> userSet,
			Set<String> itemSet) {
		double numerator = 0; // dot product of the two score vectors
		double denominator1 = 0; // squared norm of vector 1
		double denominator2 = 0; // squared norm of vector 2
		for (String itemId : itemSet) {
			// a missing score counts as 0
			double score1 = userMap1.containsKey(itemId) ? userMap1.get(itemId) : 0.0;
			double score2 = userMap2.containsKey(itemId) ? userMap2.get(itemId) : 0.0;
			// cosine similarity sums the element-wise products
			numerator += score1 * score2;
			denominator1 += score1 * score1;
			denominator2 += score2 * score2;
		}
		if (denominator1 == 0 || denominator2 == 0) {
			// an all-zero vector has no direction; treat it as not similar
			return 0;
		}
		return numerator / (Math.sqrt(denominator1) * Math.sqrt(denominator2));
	}
	
	public static double execute(Map<String, Double> userMap1,Map<String, Double> userMap2,Set<String> userSet,Set<String> itemSet) {
		/*
		double dist = EuclidDist(userMap1, userMap2, userSet, itemSet);
		double userScore = 1.0 / (1.0 + dist);
		*/
		double userScore = CosineDist(userMap1, userMap2, userSet, itemSet);
		return userScore;
	}

	public static void execute(String userId,Map<String, Map<String, Double>> scoreTable,
			Set<String> userSet, Set<String> itemSet) {
		for (Entry<String, Map<String, Double>> userEntry : scoreTable.entrySet()) {
			String userId2 = userEntry.getKey();
			Map<String, Double> userMap2 = userEntry.getValue();
			//double dist = EuclidDist(scoreTable.get(userId), userMap2, userSet, itemSet);
			//double userScore = 1.0 / (1.0 + dist);
			double userScore = CosineDist(scoreTable.get(userId), userMap2, userSet, itemSet);
			FileTool.ps1.println(userId + "," + userId2 + "," + df.format(userScore));
		}
	}

	public static void execute(Map<String, Map<String, Double>> scoreTable,
			Set<String> userSet, Set<String> itemSet) {
		int count = 0;
		for (Entry<String, Map<String, Double>> userEntry1 : scoreTable.entrySet()) {
			long startTime = System.currentTimeMillis();
			String userId = userEntry1.getKey();
			execute(userId, scoreTable, userSet, itemSet);
			count++;
			long endTime = System.currentTimeMillis();
			long dur = endTime - startTime;
			System.out.println("user count:" + count + ",dur time:" + dur + "," + (dur/1000) + "s," + (dur/1000/60) + "m");
		}
	}

}
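
Cosine similarity divides the dot product of two score vectors by the product of their norms. Because the score vectors here are sparse maps, the dot product only needs the keys that both maps share. The following is a minimal sketch under that assumption; CosineSketch and its methods are hypothetical names, not the project's API.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper for illustration only; not part of the project source.
final class CosineSketch {

    // Cosine similarity of two sparse vectors stored as id -> score maps.
    static double cosine(Map<String, Double> a, Map<String, Double> b) {
        double dot = 0.0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            Double other = b.get(e.getKey());
            if (other != null) {
                dot += e.getValue() * other; // only shared keys contribute
            }
        }
        double normA = norm(a);
        double normB = norm(b);
        if (normA == 0.0 || normB == 0.0) {
            return 0.0; // undefined for a zero vector; treated as "not similar"
        }
        return dot / (normA * normB);
    }

    // Euclidean norm of a sparse vector.
    static double norm(Map<String, Double> v) {
        double sum = 0.0;
        for (double x : v.values()) {
            sum += x * x;
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        Map<String, Double> u1 = new HashMap<>();
        u1.put("i1", 3.0);
        u1.put("i2", 4.0);
        Map<String, Double> u2 = new HashMap<>();
        u2.put("i1", 3.0);
        u2.put("i3", 4.0);
        // dot = 3 * 3 = 9, both norms are 5, so the result is 9 / 25
        System.out.println(cosine(u1, u2)); // prints 0.36
    }
}
```

Iterating over the entries of one map instead of the full item set keeps the cost proportional to the number of scores a user actually has, which matters when the item catalogue is large.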

Similarity computation module for item-based collaborative filtering

package service;

import java.text.DecimalFormat;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import util.FileTool;

public class CalculateSimilarityByItemCF {
	
	public static DecimalFormat df = new DecimalFormat("#.0000");
	
	// This measure is problematic: for (a) 1,0,1 vs 1,20,1 and (b) 0,1,0 vs 0,0,0, pair (b) comes out closer than pair (a)
	public static double EuclidDist(Map<String, Double> userMap1,
			Map<String, Double> userMap2, Set<String> userSet,
			Set<String> itemSet) {
		double sum = 0;
		for (String userId : userSet) {
			double score1 = 0.0;
			double score2 = 0.0;
			if (userMap1.containsKey(userId) && userMap2.containsKey(userId)) {
				score1 = userMap1.get(userId);
				score2 = userMap2.get(userId);
			} else if (userMap1.containsKey(userId)) {
				score1 = userMap1.get(userId);
			} else if (userMap2.containsKey(userId)) {
				score2 = userMap2.get(userId);
			}
			double temp = Math.pow((score1 - score2), 2);
			sum += temp;
		}
		sum = Math.sqrt(sum);
		return sum;
	}

	public static double CosineDist(Map<String, Double> userMap1,
			Map<String, Double> userMap2, Set<String> userSet,
			Set<String> itemSet) {
		double numerator = 0; // dot product of the two score vectors
		double denominator1 = 0; // squared norm of vector 1
		double denominator2 = 0; // squared norm of vector 2
		for (String userId : userSet) {
			// a missing score counts as 0
			double score1 = userMap1.containsKey(userId) ? userMap1.get(userId) : 0.0;
			double score2 = userMap2.containsKey(userId) ? userMap2.get(userId) : 0.0;
			// cosine similarity sums the element-wise products
			numerator += score1 * score2;
			denominator1 += score1 * score1;
			denominator2 += score2 * score2;
		}
		if (denominator1 == 0 || denominator2 == 0) {
			// an all-zero vector has no direction; treat it as not similar
			return 0;
		}
		return numerator / (Math.sqrt(denominator1) * Math.sqrt(denominator2));
	}
	
	// Co-occurrence similarity: scores are ignored, only who interacted with which item matters
	public static double SimpleDist(Map<String, Double> userMap1,
			Map<String, Double> userMap2, Set<String> userSet,
			Set<String> itemSet) {
		double numerator = 0; // number of users who interacted with both items
		double denominator1 = 0; // number of users who interacted with item 1
		double denominator2 = 0; // number of users who interacted with item 2
		for (String userId : userSet) {
			if (userMap1.containsKey(userId) && userMap2.containsKey(userId)) {
				numerator++;
				denominator1++;
				denominator2++;
			} else if (userMap1.containsKey(userId)) {
				denominator1++;
			} else if (userMap2.containsKey(userId)) {
				denominator2++;
			}
		}
		if (denominator1 == 0 || denominator2 == 0) {
			// an item without any user cannot be similar to anything
			return 0;
		}
		return numerator / (Math.sqrt(denominator1) * Math.sqrt(denominator2));
	}
	
	public static double execute(Map<String, Double> userMap1,Map<String, Double> userMap2,Set<String> userSet,Set<String> itemSet) {
		/*
		double dist = EuclidDist(userMap1, userMap2, userSet, itemSet);
		double userScore = 1.0 / (1.0 + dist);
		*/
		double userScore = CosineDist(userMap1, userMap2, userSet, itemSet);
		return userScore;
	}

	public static void execute(String itemId, Map<String, Map<String, Double>> scoreTable,
			Set<String> userSet, Set<String> itemSet) {
		for (Entry<String, Map<String, Double>> itemEntry : scoreTable.entrySet()) {
			String itemId2 = itemEntry.getKey();
			Map<String, Double> userMap2 = itemEntry.getValue();
			//double dist = EuclidDist(scoreTable.get(itemId), userMap2, userSet, itemSet);
			//double userScore = 1.0 / (1.0 + dist);
			//double userScore = CosineDist(scoreTable.get(itemId), userMap2, userSet, itemSet);
			double userScore = SimpleDist(scoreTable.get(itemId), userMap2, userSet, itemSet);
			FileTool.ps1.println(itemId + "," + itemId2 + "," + df.format(userScore));
		}
	}

	public static void execute(Map<String, Map<String, Double>> scoreTable,
			Set<String> userSet, Set<String> itemSet) {
		int count = 0;
		
		for (Entry<String, Map<String, Double>> itemEntry : scoreTable.entrySet()) {
			long startTime = System.currentTimeMillis();
			String itemId = itemEntry.getKey();
			execute(itemId, scoreTable, userSet, itemSet);
			count++;
			long endTime = System.currentTimeMillis();
			long dur = endTime - startTime;
			System.out.println("item count:" + count + ",dur time:" + dur + "," + (dur/1000) + "s," + (dur/1000/60) + "m");
		}
		
	}

}
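
SimpleDist in the listing above ignores the scores and counts co-occurrences: the number of users who interacted with both items, divided by the square root of the product of the two items' user counts. The following is a minimal sketch of the same measure on plain sets, assuming each item is represented by the set of user ids that touched it; CoOccurrenceSketch is a hypothetical name.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical helper for illustration only; not part of the project source.
final class CoOccurrenceSketch {

    // |A ∩ B| / sqrt(|A| * |B|), where A and B are the user sets of two items.
    static double similarity(Set<String> usersOfA, Set<String> usersOfB) {
        if (usersOfA.isEmpty() || usersOfB.isEmpty()) {
            return 0.0;
        }
        // iterate over the smaller set to keep the number of lookups low
        Set<String> smaller = usersOfA.size() <= usersOfB.size() ? usersOfA : usersOfB;
        Set<String> larger = smaller == usersOfA ? usersOfB : usersOfA;
        int common = 0;
        for (String userId : smaller) {
            if (larger.contains(userId)) {
                common++;
            }
        }
        return common / Math.sqrt((double) usersOfA.size() * usersOfB.size());
    }

    public static void main(String[] args) {
        Set<String> itemA = new HashSet<>(Arrays.asList("u1", "u2", "u3", "u4"));
        Set<String> itemB = new HashSet<>(Arrays.asList("u1", "u2", "u5", "u6"));
        // 2 shared users, 4 users each: 2 / sqrt(16)
        System.out.println(similarity(itemA, itemB)); // prints 0.5
    }
}
```

This is the cosine similarity of two 0/1 vectors, so it suits implicit feedback where a click count or behavior weight is not a reliable rating.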

5.    Scripts

Generate the userSet and itemSet:

package script;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

import entity.User;
import util.FileTool;

public class MakeSet {

	public static void main(String[] args) throws Exception {
		String inputDir = args[0];
		String outputDir = args[1];
		Set<String> userSet = new HashSet<String>();
		Set<String> itemSet = new HashSet<String>();
		List<String> pathList = FileTool.traverseFolder(inputDir);
		for(String path : pathList){
			String inputPath = inputDir + path;
			List<User> list = FileTool.readFileOne(inputPath, false, "\t", "user", "reduceUser");
			for(User user : list){
				userSet.add(user.getUserId());
				itemSet.add(user.getItemId());
			}
		}
		FileTool.initWriter1(outputDir+"userSet");
		for(String userId : userSet){
			FileTool.ps1.println(userId);
		}
		FileTool.closeWriter1();
		FileTool.initWriter1(outputDir+"itemSet");
		for(String itemId : itemSet){
			FileTool.ps1.println(itemId);
		}
		FileTool.closeWriter1();
		
	}

}

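User-based collaborative filtering: map the input file, build the user-item score matrix, compute the similarity between users and write the user-user score table: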

package script.runByUserCF;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import service.CalculateSimilarityByUserCF;
import service.DataProcessByUserCF;
import util.FileTool;
import entity.User;

public class SpliteFileAndMakeScoreTable {
	
	public static void main(String[] args) throws Exception {
		
		String inputDir = "data/fresh_comp_offline/sample2/";
		String outputDir = "data/fresh_comp_offline/sample2/out/user/";
		String userPath = inputDir + "train_user.csv";
		String outputPath = "data/fresh_comp_offline/sample2/out/score.csv";
		String scoreTablePath = "data/fresh_comp_offline/sample2/out/scoreTable.csv";
		
		/*
		String inputDir = "data/ml-1m/";
		String outputDir = "data/ml-1m/out/user/";
		String userPath = inputDir + "ratings.dat";
		String outputPath = "data/ml-1m/out/score.csv";
		*/
		/*
		String inputDir = args[0];
		String outputDir = args[1];
		String userPath = inputDir + args[2];
		String outputPath = args[3];
		String scoreTablePath = args[4];
		*/
		
		//FileTool.makeSampleData(userPath, true, outputPath, 10000);
		//List<Object> itemList = FileTool.readFileOne(itemPath, true, ",", "item");
		List<User> userList = FileTool.readFileOne(userPath, true, ",", "user", "mapUser"); //sample2
		//List<User> userList = FileTool.readFileOne(userPath, true, "::", "user", "mapUser"); //movielens
		Set<String> userSet = new HashSet<String>();
		Set<String> itemSet = new HashSet<String>();
		Map<String, Map<String, List<User>>> userMap = DataProcessByUserCF.mapByUser(userList,userSet,itemSet);
		userList.clear();
		DataProcessByUserCF.output(userMap, outputDir);
		
		//build the user-to-item score table
		Map<String, Map<String, Double>> scoreTable = DataProcessByUserCF.makeScoreTable(userMap);
		DataProcessByUserCF.output(scoreTable, scoreTablePath , userSet, itemSet, ",");
		userMap.clear();
		FileTool.initWriter1(outputPath);
		CalculateSimilarityByUserCF.execute(scoreTable, userSet, itemSet);
		FileTool.closeWriter1();
		
	}

}

User-based collaborative filtering: reduce the mapped files and sort each user's records:

package script.runByUserCF;

import java.util.Collections;
import java.util.List;

import service.DataProcessByUserCF;
import util.FileTool;
import entity.User;

public class ReduceFileTest {
	public static void main(String[] args) throws Exception {
		
		String inputDir = "data/fresh_comp_offline/sample2/out/user/";
		String outputDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
		
		/*
		String inputDir = "data/ml-1m/out/user/";
		String outputDir = "data/ml-1m/out/sorteduser/";
		*/
		/*
		String inputDir = args[0];
		String outputDir = args[1];
		*/
		
		List<String> pathList = FileTool.traverseFolder(inputDir);
		for(String path : pathList){
			List<User> userList = FileTool.readFileOne(inputDir+path, false, "\t", "user", "reduceUser");
			List<User> list = DataProcessByUserCF.reduceUserByItem(userList);
			userList.clear();
			FileTool.initWriter1(outputDir + path);
			Collections.sort(list);
			DataProcessByUserCF.outputUser(list);
			FileTool.closeWriter1();
			list.clear();
		}
	}
}

User-based collaborative filtering: recommend for each user and generate the prediction list:

package script.runByUserCF;

import java.util.List;
import java.util.Map;
import java.util.Set;

import service.DataProcessByUserCF;
import util.FileTool;
import entity.Score;


public class PredictTest {
	public static void main(String[] args) throws Exception {
		/*
		String inputDir = "data/fresh_comp_offline/sample2/out/";
		String outputDir = "data/fresh_comp_offline/sample2/out/";
		String inputPath = inputDir + "score.csv";
		String outputPath = outputDir + "predict";
		String userDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
		*/
		
		String inputDir = "data/fresh_comp_offline/sample2/out/";
		String outputDir = "data/fresh_comp_offline/sample2/out/";
		String inputPath = inputDir + "score.csv";
		String outputPath = outputDir + "predict";
		String userDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
		int topNUser = 5;
		double score_threshold = 0;
		int topNItem = 5;
		double weight_threshold = 1.0;
		
		/*
		String inputDir = args[0];
		String outputDir = args[1];
		String inputPath = inputDir + args[2];
		String outputPath = outputDir + args[3];
		String userDir = args[4];
		int topNUser = Integer.parseInt(args[5]);
		double score_threshold = Double.parseDouble(args[6]);
		int topNItem = Integer.parseInt(args[7]);
		double weight_threshold = Double.parseDouble(args[8]);
		*/
		Map<String, List<Score>> scoreMap = FileTool.loadScoreMap(inputPath, false, ",", "userCF");
		DataProcessByUserCF.sortScoreMap(scoreMap);
		List<String> fileNameList = FileTool.traverseFolder(userDir);
		//recommend the top 5 items of the 5 users most similar to this user
		Map<String, Set<String>> predictMap = DataProcessByUserCF.predict(scoreMap, fileNameList, userDir, topNUser, score_threshold, topNItem, weight_threshold);
		FileTool.initWriter1(outputPath);
		DataProcessByUserCF.outputRecommendList(predictMap);
		FileTool.closeWriter1();
		scoreMap.clear();
	}
}
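
The idea stated in the script's comment, taking the top items of the most similar users, can be sketched independently of the file layout. This is not the project's predict method: it assumes the similarities to the target user and each user's ranked item list are already in memory, and TopNSketch, recommend and their parameters are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper for illustration only; not part of the project source.
final class TopNSketch {

    // Collect the first topNItem items of the topNUser most similar users.
    static Set<String> recommend(String target,
            Map<String, Double> similarityToTarget,
            Map<String, List<String>> rankedItems,
            int topNUser, int topNItem) {
        List<Map.Entry<String, Double>> neighbors = new ArrayList<>();
        for (Map.Entry<String, Double> e : similarityToTarget.entrySet()) {
            if (!e.getKey().equals(target)) { // a user is not their own neighbor
                neighbors.add(e);
            }
        }
        // most similar neighbor first
        neighbors.sort((x, y) -> Double.compare(y.getValue(), x.getValue()));

        Set<String> result = new LinkedHashSet<>(); // keeps order, drops duplicates
        int userLimit = Math.min(topNUser, neighbors.size());
        for (Map.Entry<String, Double> neighbor : neighbors.subList(0, userLimit)) {
            List<String> items = rankedItems.getOrDefault(
                    neighbor.getKey(), Collections.<String>emptyList());
            result.addAll(items.subList(0, Math.min(topNItem, items.size())));
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Double> sim = new HashMap<>();
        sim.put("u1", 1.0);
        sim.put("u2", 0.9);
        sim.put("u3", 0.2);
        sim.put("u4", 0.7);
        Map<String, List<String>> ranked = new HashMap<>();
        ranked.put("u2", Arrays.asList("a", "b", "c"));
        ranked.put("u3", Arrays.asList("e"));
        ranked.put("u4", Arrays.asList("b", "d"));
        // neighbors u2 and u4 contribute [a, b] and [b, d]
        System.out.println(recommend("u1", sim, ranked, 2, 2)); // prints [a, b, d]
    }
}
```

The sketch does not filter out items the target user has already interacted with; whether to do so depends on whether repeat purchases should count as valid predictions.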

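Item-based collaborative filtering: map the input file, build the user-item score matrix, compute the similarity between items and write the item-item score table: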

package script.runByItemCF;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import service.CalculateSimilarityByItemCF;
import service.DataProcessByItemCF;
import util.FileTool;
import entity.User;

public class SpliteFileAndMakeScoreTable {
	
	public static void main(String[] args) throws Exception {
		/*
		String inputDir = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/";
		String outputDir = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/out/user/";
		String userPath = inputDir + "train_als_itemCF1.csv";
		String outputPath = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/out/score.csv";
		String scoreTablePath = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/out/scoreTable.csv";
		String token = ",";
		*/
		
		
		String inputDir = args[0];
		String outputDir = args[1];
		String userPath = inputDir + args[2];
		String outputPath = args[3];
		String scoreTablePath = args[4];
		String token = args[5];
		
		
		//FileTool.makeSampleData(userPath, true, outputPath, 10000);
		//List<Object> itemList = FileTool.readFileOne(itemPath, true, ",", "item");
		//List<User> userList = FileTool.readFileOne(userPath, true, ",", "user", "mapUser"); //sample2
		//List<User> userList = FileTool.readFileOne(userPath, true, "::", "user", "mapUser"); //movielens
		List<User> userList = FileTool.readFileOne(userPath, true, token, "user", "mapUser");
		Set<String> userSet = new HashSet<String>();
		Set<String> itemSet = new HashSet<String>();
		Map<String, Map<String, List<User>>> userMap = DataProcessByItemCF.mapByUser(userList,userSet,itemSet);
		userList.clear();
		DataProcessByItemCF.output(userMap, outputDir);
		
		//build the item-to-user score table
		Map<String, Map<String, Double>> scoreTable = DataProcessByItemCF.makeScoreTable(userMap);
		DataProcessByItemCF.output(scoreTable, scoreTablePath , userSet, itemSet, ",");
		userMap.clear();
		FileTool.initWriter1(outputPath);
		CalculateSimilarityByItemCF.execute(scoreTable, userSet, itemSet);
		FileTool.closeWriter1();
		
	}

}

Item-based collaborative filtering: reduce the mapped files and sort each user's records:

package script.runByItemCF;

import java.util.Collections;
import java.util.List;

import service.DataProcessByUserCF;
import util.FileTool;
import entity.User;

public class ReduceFileTest {
	public static void main(String[] args) throws Exception {
		/*
		String inputDir = "data/fresh_comp_offline/sample2/out/user/";
		String outputDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
		*/
		/*
		String inputDir = "data/ml-1m/out/user/";
		String outputDir = "data/ml-1m/out/sorteduser/";
		*/
		
		String inputDir = args[0];
		String outputDir = args[1];
		
		
		List<String> pathList = FileTool.traverseFolder(inputDir);
		for(String path : pathList){
			List<User> userList = FileTool.readFileOne(inputDir+path, false, "\t", "user", "reduceUser");
			List<User> list = DataProcessByUserCF.reduceUserByItem(userList);
			userList.clear();
			FileTool.initWriter1(outputDir + path);
			Collections.sort(list);
			DataProcessByUserCF.outputUser(list);
			FileTool.closeWriter1();
			list.clear();
		}
	}
}

Item-based collaborative filtering: recommend for each user and generate the prediction list:

package script.runByItemCF;

import java.util.List;
import java.util.Map;
import java.util.Set;

import service.DataProcessByItemCF;
import util.FileTool;
import entity.Score;


public class PredictTest {
	public static void main(String[] args) throws Exception {
		/*
		String inputDir = "data/fresh_comp_offline/sample2/out/";
		String outputDir = "data/fresh_comp_offline/sample2/out/";
		String inputPath = inputDir + "score.csv";
		String outputPath = outputDir + "predict";
		String userDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
		int topNUser = 5;
		double score_threshold = 0;
		int topNItem = 5;
		double weight_threshold = 0;
		*/
		/*
		String inputDir = "data/ml-1m/out/out_itemCF/";
		String outputDir = "data/ml-1m/out/out_itemCF/";
		String inputPath = inputDir + "score.csv";
		String outputPath = outputDir + "predict";
		String userDir = "data/ml-1m/out/out_itemCF/sorteduser/";
		int topNUser = 5;
		double score_threshold = 0;
		int topNItem = 5;
		double weight_threshold = 0;
		*/
		
		
		String inputDir = args[0];
		String outputDir = args[1];
		String inputPath = inputDir + args[2];
		String outputPath = outputDir + args[3];
		String userDir = args[4];
		int topNUser = Integer.parseInt(args[5]);
		double score_threshold = Double.parseDouble(args[6]);
		int topNItem = Integer.parseInt(args[7]);
		double weight_threshold = Double.parseDouble(args[8]);
		
		
		Map<String, List<Score>> scoreMap = FileTool.loadScoreMap(inputPath, false, ",", "itemCF");
		DataProcessByItemCF.sortScoreMap(scoreMap);
		List<String> fileNameList = FileTool.traverseFolder(userDir);
		//recommend the top 5 items of the 5 users most similar to this user
		Map<String, Set<String>> predictMap = DataProcessByItemCF.predict(scoreMap, fileNameList, userDir, topNUser, score_threshold, topNItem, weight_threshold);
		FileTool.initWriter1(outputPath);
		DataProcessByItemCF.outputRecommendList(predictMap);
		FileTool.closeWriter1();
		scoreMap.clear();
	}
}

Compute precision, recall and the F-measure:

package script;

import java.util.List;
import java.util.Map;

import service.DataProcessByUserCF;
import util.FileTool;

public class MatchTest {

	public static void main(String[] args) throws Exception {
		
		//String predictPath = "data/ml-1m/out/predict_itemCF2";
		//String predictPath = "data/ml-1m/out/predict_userCF";
		//String predictPath = "data/ml-1m/out/predict_ALS_15_600_0.388_0.01";
		//String predictPath = "data/ml-1m/out/predict_ml_ALS_15_600_0.388_0.01_1.csv";
		//String predictPath = "data/unicom/outRecommend/out/out_outRecommedList3.csv";
		//String predictPath = "data/unicom/allRecommend/out/out_outRecommedList250_0.4.csv";
		//String predictPath = "data/unicom/allRecommend/out/predict_als_randomForest.csv";
		String predictPath = "data/unicom/out/test_pred.txt";
		
		//String predictPath = "data/ml-1m/out/predict_ml_userCF_RF.csv";
		//String predictPath = "data/ml-1m/out/predict_ml_userCF.csv";
		//String predictPath = "data/ml-1m/out/predict_ml_ALS.csv";
		//String predictPath = "data/ml-1m/out/predict_ml_all1.csv";
		//String predictPath = "data/ml-1m/out/predict_ml_ALS_15_600_0.388_0.01.csv";
		
		//String referencePath = "data/ml-1m/sample/testData.dat";
		//String referencePath = "data/unicom/outRecommend/test3_list.csv";
		//String referencePath = "data/unicom/allRecommend/test_als.csv";
		String referencePath = "data/unicom/out/test_real.txt";
		
		Map<String, List<String>> predictMap = FileTool.loadPredictData(predictPath, false, ",", 50);
		int predictN = FileTool.count;
		FileTool.count = 0;
		Map<String, List<String>> referenceMap = FileTool.loadPredictData(referencePath, false, ",", 10000);
		int referenceN = FileTool.count;
		DataProcessByUserCF.prediction(predictMap, predictN, referenceMap, referenceN);
	}

}
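
Precision is the share of predicted (user, item) pairs that appear in the reference set, recall is the share of reference pairs that were predicted, and F1 is their harmonic mean. The following is a minimal sketch on in-memory sets, assuming each pair is encoded as one "userId,itemId" string; F1Sketch is a hypothetical name, and unlike the listing above it returns fractions rather than percentages.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical helper for illustration only; not part of the project source.
final class F1Sketch {

    // Returns {precision, recall, f1} as fractions in [0, 1].
    static double[] evaluate(Set<String> predicted, Set<String> reference) {
        int hits = 0;
        for (String pair : predicted) {
            if (reference.contains(pair)) {
                hits++;
            }
        }
        double precision = predicted.isEmpty() ? 0.0 : (double) hits / predicted.size();
        double recall = reference.isEmpty() ? 0.0 : (double) hits / reference.size();
        // without any hit the harmonic mean is undefined; report 0 instead of NaN
        double f1 = (precision + recall == 0.0)
                ? 0.0
                : 2 * precision * recall / (precision + recall);
        return new double[] {precision, recall, f1};
    }

    public static void main(String[] args) {
        Set<String> predicted = new HashSet<>(
                Arrays.asList("u1,i1", "u1,i2", "u2,i3", "u2,i4"));
        Set<String> reference = new HashSet<>(
                Arrays.asList("u1,i1", "u2,i3", "u3,i5", "u3,i6"));
        // 2 hits out of 4 predictions and 4 reference pairs
        System.out.println(Arrays.toString(evaluate(predicted, reference)));
        // prints [0.5, 0.5, 0.5]
    }
}
```

Using sets enforces the competition's requirement that the result contains no duplicate pairs, since a duplicate would otherwise inflate both the hit count and the prediction count.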

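Script execution order:

(1). User-based collaborative filtering recommendation:
1.script.runByUserCF.SpliteFileAndMakeScoreTable.java //map the input file and generate the user-user scores
2.script.runByUserCF.ReduceFileTest.java  //sort the mapped files, mainly by the score of each item within a user
3.script.runByUserCF.PredictTest.java   //predict and generate the user-item prediction list
4.runMakeTestSet.sh  //generate the test set
5.runSpliteFile.sh   //map the test-set file
6.runMatch.sh     //match the prediction list against the test set and compute precision and recall
(2). Item-based collaborative filtering recommendation:
1.script.runByItemCF.SpliteFileAndMakeScoreTable.java //map the input file and generate the item-item scores
2.script.runByItemCF.ReduceFileTest.java  //sort the mapped files, mainly by the score of each item within a user
3.script.runByItemCF.PredictTest.java   //predict and generate the user-item prediction list
4.runMakeTestSet.sh  //generate the test set
5.runSpliteFile.sh   //map the test-set file
6.runMatch.sh     //match the prediction list against the test set and compute precision and recall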

The above is the core code. The full project source is available at:

https://download.csdn.net/download/u013473512/10141066

https://github.com/Emmitte/recommendSystem

You are welcome to follow the "程序杂货铺" official account for more content.



Publisher: 全栈程序员-用户IM. Please credit the source when reposting: https://javaforall.cn/125710.html Original link: https://javaforall.cn
