Hello everyone, we meet again. I'm your friend Quanzhanjun (全栈君).
I. Project Requirements
1. Requirement Link
https://tianchi.aliyun.com/getStart/information.htm?raceId=231522
2. Requirement Details
Competition Problem
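In real business scenarios, we often need to build a personalized recommendation model for only a subset of all items. To do this, we need not only the users' behavior data on that item subset but usually richer behavior data as well. Define the following notation:
U: the set of users
I: the full item set
P: the item subset, P ⊆ I
D: the users' behavior data on the full item set
Our goal is then to use D to build a model that recommends items in P to the users in U.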
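As a minimal illustration of the U × P target, the hypothetical helper below (class and method names are our own, not from the competition or this project) keeps only candidate (user, item) pairs whose item belongs to P, even if the candidates were scored using behavior on all of I.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical helper: restricts candidate recommendations to the item subset P.
class SubsetFilter {
    // One candidate (user_id, item_id) recommendation.
    static final class Pair {
        final String userId;
        final String itemId;

        Pair(String userId, String itemId) {
            this.userId = userId;
            this.itemId = itemId;
        }
    }

    // Candidates may come from behaviour on the full item set I (data D);
    // only pairs whose item is in P are kept for the final answer.
    static List<Pair> restrictToSubset(List<Pair> candidates, Set<String> subsetP) {
        List<Pair> kept = new ArrayList<>();
        for (Pair p : candidates) {
            if (subsetP.contains(p.itemId)) {
                kept.add(p);
            }
        }
        return kept;
    }
}
```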
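Data Description
The competition provides complete behavior data for 20,000 users along with information on millions of items. The data has two parts.
The first part is the users' mobile behavior data on the full item set (D), in the table tianchi_fresh_comp_train_user_2w, with the following fields: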
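| Field | Description | Extraction note |
| --- | --- | --- |
| user_id | User identifier | Sampled & anonymized |
| item_id | Item identifier | Anonymized |
| behavior_type | The user's behavior type on the item | Browse, favorite, add to cart, purchase, encoded as 1, 2, 3, 4 respectively |
| user_geohash | Spatial identifier of the user's location; may be empty | Generated from latitude/longitude by a confidential algorithm |
| item_category | Item category identifier | Anonymized |
| time | Behavior time | Precise to the hour |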
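The second part is the item subset (P), in the table tianchi_fresh_comp_train_item_2w, with the following fields:
| Field | Description | Extraction note |
| --- | --- | --- |
| item_id | Item identifier | Sampled & anonymized |
| item_geohash | Spatial identifier of the item's location; may be empty | Generated from latitude/longitude by a confidential algorithm |
| item_category | Item category identifier | Anonymized |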
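The training data contains the mobile behavior data (D) of a sample of users over one month (11.18 to 12.18); the scoring data is these users' purchases of items in the subset (P) on the day after that month (12.19). Participants build a recommendation model from the training data and output predictions of which subset items each user will buy on that following day.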
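This setup is a time-based split: a month of behavior as input and one later day's purchases as labels. Since the real 12.19 labels are hidden, a common offline approach (an assumption here, not something this article prescribes) is to hold out the last training day instead. The sketch below, with hypothetical names, splits rows at a given label day and collects that day's purchases (behavior_type 4) as the reference set. It relies on the time column's "yyyy-MM-dd HH" format, whose date prefix sorts correctly as a string.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical day-based split for offline validation.
class DaySplit {
    static final class Row {
        final String userId;
        final String itemId;
        final int behaviorType; // 1 browse, 2 favorite, 3 add to cart, 4 purchase
        final String time;      // e.g. "2014-12-08 18"

        Row(String userId, String itemId, int behaviorType, String time) {
            this.userId = userId;
            this.itemId = itemId;
            this.behaviorType = behaviorType;
            this.time = time;
        }
    }

    final List<Row> train = new ArrayList<>();           // behaviour strictly before labelDay
    final Set<String> reference = new LinkedHashSet<>(); // "user_id,item_id" purchases on labelDay

    DaySplit(List<Row> rows, String labelDay) {          // labelDay e.g. "2014-12-18"
        for (Row r : rows) {
            // "yyyy-MM-dd" strings compare lexicographically in date order
            int cmp = r.time.substring(0, 10).compareTo(labelDay);
            if (cmp < 0) {
                train.add(r);
            } else if (cmp == 0 && r.behaviorType == 4) {
                reference.add(r.userId + "," + r.itemId);
            }
            // rows after labelDay are ignored
        }
    }
}
```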
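Submission Format
After predicting users' purchases of subset items, participants must put the results into a data table (non-partitioned) in the specified format. The result table must be named tianchi_mobile_recommendation_predict.csv, be encoded in UTF-8, and contain two columns, user_id and item_id (both of type string), with duplicates removed.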
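A minimal sketch of producing that file, assuming the recommendations are held in a user_id → item list map (as the project code later does) and assuming a user_id,item_id header row is wanted; the class name is hypothetical. A LinkedHashSet removes duplicate pairs, and the file is written explicitly as UTF-8.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical writer for tianchi_mobile_recommendation_predict.csv.
class SubmissionWriter {
    static void write(Map<String, List<String>> recommendations, Path out) throws IOException {
        // LinkedHashSet drops repeated "user_id,item_id" rows but keeps first-seen order
        Set<String> rows = new LinkedHashSet<>();
        for (Map.Entry<String, List<String>> e : recommendations.entrySet()) {
            for (String itemId : e.getValue()) {
                rows.add(e.getKey() + "," + itemId);
            }
        }
        List<String> lines = new ArrayList<>();
        lines.add("user_id,item_id"); // header row (assumed)
        lines.addAll(rows);
        Files.write(out, lines, StandardCharsets.UTF_8);
    }
}
```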
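Evaluation Metrics
The competition uses the classic precision, recall, and F1 score as evaluation metrics, computed as follows:
$$\text{Precision} = \frac{|PredictionSet \cap ReferenceSet|}{|PredictionSet|}$$
$$\text{Recall} = \frac{|PredictionSet \cap ReferenceSet|}{|ReferenceSet|}$$
$$F_1 = \frac{2 \times \text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}$$
Here PredictionSet is the set of purchases predicted by the algorithm and ReferenceSet is the set of actual purchases. The F1 score is the sole final evaluation criterion.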
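The three formulas can be checked with a small sketch that treats each predicted and each actual purchase as a "user_id,item_id" string; the class name is hypothetical, and returning 0 for empty sets is our own convention to avoid division by zero.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical implementation of the competition's precision / recall / F1.
class F1Metric {
    static double precision(Set<String> prediction, Set<String> reference) {
        return prediction.isEmpty() ? 0.0 : (double) hits(prediction, reference) / prediction.size();
    }

    static double recall(Set<String> prediction, Set<String> reference) {
        return reference.isEmpty() ? 0.0 : (double) hits(prediction, reference) / reference.size();
    }

    static double f1(Set<String> prediction, Set<String> reference) {
        double p = precision(prediction, reference);
        double r = recall(prediction, reference);
        return (p + r) == 0.0 ? 0.0 : 2 * p * r / (p + r);
    }

    // |PredictionSet ∩ ReferenceSet|
    private static int hits(Set<String> prediction, Set<String> reference) {
        Set<String> common = new HashSet<>(prediction);
        common.retainAll(reference);
        return common.size();
    }
}
```

For example, 4 predictions with 2 hits against 3 actual purchases give precision 0.5, recall 2/3, and F1 = 4/7 ≈ 0.571.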
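II. Principles and Workflow of Collaborative Filtering Recommendation
1. User-Based Collaborative Filtering
User-based collaborative filtering finds neighbor users whose ratings resemble the target user's and, by looking at the items those neighbors like, infers that the target user shares their tastes. The basic idea is: use the user-item rating matrix to find the current user's nearest neighbors, use the neighbors' ratings to predict the current user's rating for each item, and recommend the N highest-scoring items to the user. Here an "item" means a product handled by the system. The workflow is shown in Figure 1.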
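Figure 1. Workflow of user-based collaborative filtering
The user-based collaborative filtering workflow is:
1) Build the user-item rating matrix
$R = \{R_1, R_2, \ldots, R_m\}^T$ is the m × n user rating matrix, where $R_i = \{r_{i1}, r_{i2}, \ldots, r_{in}\}$ is the rating vector of user $u_i$ and $r_{ij}$ is user $u_i$'s rating of item $i_j$.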
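This competition has no explicit ratings, so the project code (makeScoreTable, later in the article) derives an implicit rating from behavior. A self-contained sketch of the same rule, r = 10 × (strongest behavior_type - 1) + number of records for the user-item pair, stored as a sparse user → (item → rating) map; class and field names are hypothetical.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sparse rating-matrix builder mirroring makeScoreTable's weighting.
class RatingMatrix {
    static final class Action {
        final String userId;
        final String itemId;
        final int behaviorType; // 1 browse, 2 favorite, 3 add to cart, 4 purchase

        Action(String userId, String itemId, int behaviorType) {
            this.userId = userId;
            this.itemId = itemId;
            this.behaviorType = behaviorType;
        }
    }

    static Map<String, Map<String, Double>> build(List<Action> actions) {
        Map<String, Map<String, Integer>> maxType = new HashMap<>();
        Map<String, Map<String, Integer>> count = new HashMap<>();
        for (Action a : actions) {
            // strongest behaviour seen for this (user, item) pair
            maxType.computeIfAbsent(a.userId, k -> new HashMap<>()).merge(a.itemId, a.behaviorType, Integer::max);
            // number of records for this (user, item) pair
            count.computeIfAbsent(a.userId, k -> new HashMap<>()).merge(a.itemId, 1, Integer::sum);
        }
        Map<String, Map<String, Double>> ratings = new HashMap<>();
        for (Map.Entry<String, Map<String, Integer>> user : maxType.entrySet()) {
            Map<String, Double> row = new HashMap<>();
            for (Map.Entry<String, Integer> item : user.getValue().entrySet()) {
                int n = count.get(user.getKey()).get(item.getKey());
                row.put(item.getKey(), 10.0 * (item.getValue() - 1) + n);
            }
            ratings.put(user.getKey(), row);
        }
        return ratings;
    }
}
```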
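2) Compute user similarity
User-based CF needs to find users similar to the target user. Measuring the similarity between users means comparing each user's ratings with every other user's ratings, i.e., the user rows of the rating matrix. Each user's ratings can be viewed as an n-dimensional rating vector, and the similarity sim(i, j) between the target user $u_i$ and another user $u_j$ is computed from these vectors. Three methods are commonly used: cosine similarity, adjusted cosine similarity, and the Pearson correlation coefficient.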
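Two of the three measures can be sketched directly on sparse rating maps (item id → rating). This is an illustration with hypothetical names: cosine treats a missing rating as 0, while Pearson uses only co-rated items and returns 0 when fewer than two items overlap (our own choice). Adjusted cosine, which subtracts each user's mean rating, is omitted.

```java
import java.util.Map;

// Hypothetical similarity functions over sparse rating vectors.
class Similarity {
    // Cosine similarity; items missing from one map count as rating 0.
    static double cosine(Map<String, Double> a, Map<String, Double> b) {
        double dot = 0, normA = 0, normB = 0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            Double other = b.get(e.getKey());
            if (other != null) {
                dot += e.getValue() * other;
            }
            normA += e.getValue() * e.getValue();
        }
        for (double v : b.values()) {
            normB += v * v;
        }
        return (normA == 0 || normB == 0) ? 0.0 : dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Pearson correlation over the items both users rated.
    static double pearson(Map<String, Double> a, Map<String, Double> b) {
        int n = 0;
        double sumA = 0, sumB = 0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            Double other = b.get(e.getKey());
            if (other != null) {
                n++;
                sumA += e.getValue();
                sumB += other;
            }
        }
        if (n < 2) {
            return 0.0;
        }
        double meanA = sumA / n, meanB = sumB / n;
        double cov = 0, varA = 0, varB = 0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            Double other = b.get(e.getKey());
            if (other == null) {
                continue;
            }
            double da = e.getValue() - meanA, db = other - meanB;
            cov += da * db;
            varA += da * da;
            varB += db * db;
        }
        return (varA == 0 || varB == 0) ? 0.0 : cov / Math.sqrt(varA * varB);
    }
}
```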
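3) Build the nearest-neighbor set
The nearest-neighbor set Neighbor(u) contains the other users whose tastes match the target user's. To select neighbors, we first compute the similarity sim(u, v) between the target user u and every other user v, then keep the k users with the highest similarity. User similarity can be read as the trust or recommendation weight between users. Typically sim(u, v) ∈ [-1, 1]: a similarity of 1 means the two users carry a large recommendation weight for each other, while -1 means their interests differ greatly, so their mutual recommendation weight is very small.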
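Selecting Neighbor(u) is a top-k problem over the similarities sim(u, v). A minimal sketch with hypothetical names, using a size-k min-heap so that m candidate users cost O(m log k) rather than a full sort:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical top-k neighbour selection.
class NeighborSet {
    // sims: other user id -> sim(u, v) for a fixed target user u.
    // Returns at most k user ids, most similar first.
    static List<String> topK(Map<String, Double> sims, int k) {
        // min-heap ordered by similarity: the head is the least similar kept user
        PriorityQueue<Map.Entry<String, Double>> heap =
                new PriorityQueue<>((x, y) -> Double.compare(x.getValue(), y.getValue()));
        for (Map.Entry<String, Double> e : sims.entrySet()) {
            heap.offer(e);
            if (heap.size() > k) {
                heap.poll(); // evict the least similar candidate
            }
        }
        List<String> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(heap.poll().getKey()); // ascending similarity
        }
        Collections.reverse(result);          // most similar first
        return result;
    }
}
```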
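4) Compute predicted ratings
User a's predicted rating p(a, i) for item i is a weighted combination of the neighbors' ratings for that item. Clearly, different neighbors influence the target user to different degrees, so each neighbor gets its own weight when computing the prediction. We use user similarity as the weighting factor, with the following formula: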
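The formula itself did not survive extraction. One widely used form of this weighted prediction, assumed here rather than taken from the source, is the mean-centered (Resnick) version $p(a,i) = \bar{r}_a + \frac{\sum_{b} sim(a,b)\,(r_{b,i} - \bar{r}_b)}{\sum_{b} |sim(a,b)|}$, summing over neighbors b in Neighbor(a) who rated item i. A sketch with hypothetical names:

```java
import java.util.Collections;
import java.util.Map;

// Hypothetical user-based prediction using mean-centred weighting (an assumed formula).
class UserCfPredict {
    // Average rating of one user over the items they rated.
    static double mean(Map<String, Double> row) {
        if (row.isEmpty()) {
            return 0.0;
        }
        double sum = 0;
        for (double v : row.values()) {
            sum += v;
        }
        return sum / row.size();
    }

    // neighborSims: neighbour b in Neighbor(a) -> sim(a, b)
    // ratings: user -> (item -> rating)
    static double predict(String a, String item, Map<String, Double> neighborSims,
                          Map<String, Map<String, Double>> ratings) {
        double base = mean(ratings.getOrDefault(a, Collections.emptyMap()));
        double num = 0;
        double den = 0;
        for (Map.Entry<String, Double> e : neighborSims.entrySet()) {
            Map<String, Double> row = ratings.get(e.getKey());
            if (row == null || !row.containsKey(item)) {
                continue; // this neighbour never rated the item
            }
            num += e.getValue() * (row.get(item) - mean(row));
            den += Math.abs(e.getValue());
        }
        return den == 0 ? base : base + num / den; // fall back to a's own mean
    }
}
```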
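The implementation steps of user-based CF are:
1) Tally each user's scores for items in real time to produce the user-item table (i.e., build the user-item rating matrix);
2) Compute the similarity between every pair of users to produce the user-user score table, and sort it;
3) Sort each user's item set;
4) For the user who will receive recommendations, pick the N most similar users from the user-user table, then take the top M items from each of those N users' sorted item sets in the user-item table;
5) These N × M items form the recommendation set for that user.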
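Steps 2 to 5 can be sketched as follows, assuming the neighbor list is already sorted by similarity and the user-item table is a map of scores; names are hypothetical. Like the project's predict method, it collects items into a LinkedHashSet, so items shared by several neighbors appear once and the result may hold fewer than N × M items.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical top-N users x top-M items recommendation (user-based CF).
class UserCfRecommend {
    // neighbors: the target's neighbour ids, sorted by similarity (step 2).
    // ratings: user -> (item -> score), the user-item table (step 1).
    static Set<String> recommend(List<String> neighbors, Map<String, Map<String, Double>> ratings,
                                 int topNUser, int topMItem) {
        Set<String> result = new LinkedHashSet<>(); // duplicates collapse, order kept
        for (String v : neighbors.subList(0, Math.min(topNUser, neighbors.size()))) {
            Map<String, Double> row = ratings.get(v);
            if (row == null) {
                continue;
            }
            // step 3: this neighbour's items, highest score first
            List<Map.Entry<String, Double>> items = new ArrayList<>(row.entrySet());
            items.sort(Map.Entry.<String, Double>comparingByValue().reversed());
            // step 4: take the top M items of this neighbour
            for (int j = 0; j < Math.min(topMItem, items.size()); j++) {
                result.add(items.get(j).getKey());
            }
        }
        return result;
    }
}
```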
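2. Item-Based Collaborative Filtering
Item-based collaborative filtering uses the user-item rating matrix to measure how similar items are from their ratings, and takes the n items most similar to the target item as its nearest-neighbor set. It then predicts the rating of the current item by giving weights to those similar neighbors, sorts the resulting predicted ratings, and recommends the N highest-scoring items to the current user. Again, an "item" means a product handled by the system. The workflow is shown in Figure 2.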
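Figure 2. Workflow of item-based collaborative filtering
The item-based collaborative filtering workflow is:
First, read the target user's set of rated items $I_u$; then compute the similarity between item i and the other items in $I_u$ and select the k nearest neighbors; compute the predicted rating of every candidate item with the rating formula; finally, recommend the N items with the highest predicted ratings to the user.
In item-based CF, the predicted rating is a weighted value of the user's ratings of other items. The historically rated items are related to the current item i to different degrees, so each item gets its own weight. The prediction function p(u, i) uses item similarity as the weighting factor, giving the following formula: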
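This formula is also missing from the extracted text. A common form, assumed here, is the similarity-weighted average of the user's own ratings on the neighboring items: $p(u,i) = \frac{\sum_{j \in N(i) \cap I_u} sim(i,j)\, r_{u,j}}{\sum_{j \in N(i) \cap I_u} |sim(i,j)|}$, where N(i) is the k-nearest-neighbor set of item i. A sketch with hypothetical names:

```java
import java.util.Map;

// Hypothetical item-based prediction (assumed weighted-average formula).
class ItemCfPredict {
    // userRatings: items the user rated (I_u) -> rating r_{u,j}
    // simToTarget: neighbour item j in N(i) -> sim(i, j) for the target item i
    static double predict(Map<String, Double> userRatings, Map<String, Double> simToTarget) {
        double num = 0;
        double den = 0;
        for (Map.Entry<String, Double> e : simToTarget.entrySet()) {
            Double r = userRatings.get(e.getKey());
            if (r == null) {
                continue; // only neighbours the user actually rated contribute
            }
            num += e.getValue() * r;
            den += Math.abs(e.getValue());
        }
        return den == 0 ? 0.0 : num / den;
    }
}
```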
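The implementation steps of item-based CF are:
1) Tally each user's scores for items in real time to produce the user-item table (i.e., build the user-item rating matrix);
2) Compute the similarity between every pair of items to produce the item-item score table, and sort it;
3) Sort each user's item set;
4) For the user who will receive recommendations, take the items the user has already chosen and, using the item-item table, select the N items most similar to them;
5) These N items form the recommendation set for that user.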
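Steps 4 and 5 amount to a union of per-item neighbor lists. A sketch assuming the item-item table is already sorted (most similar first) and taking N neighbors per chosen item, as the project's predict method does with topNItem; names are hypothetical, and items missing from the table are skipped.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical item-based recommendation from a sorted item-item table.
class ItemCfRecommend {
    // chosenItems: items the user has interacted with, best first (step 3).
    // similarItems: item -> other items sorted by similarity, highest first (step 2).
    static Set<String> recommend(List<String> chosenItems, Map<String, List<String>> similarItems, int n) {
        Set<String> result = new LinkedHashSet<>();
        for (String item : chosenItems) {
            List<String> neighbors = similarItems.get(item);
            if (neighbors == null) {
                continue; // item absent from the item-item table
            }
            // step 4: the n items most similar to this chosen item
            for (String candidate : neighbors.subList(0, Math.min(n, neighbors.size()))) {
                result.add(candidate);
            }
        }
        return result;
    }
}
```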
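3. Comparing User-Based and Item-Based Collaborative Filtering
User-based CF:
Helps users discover new items, but requires more complex online computation and must deal with the new-user problem.
Item-based CF:
Accurate, stable, controllable, and easy to compute offline, but its recommendations are less diverse and rarely surprise the user.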
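III. Project Implementation
For mobile recommendation, we chose to implement user-based collaborative filtering.
1. Data Model and Entity Classes
User behavior data (user.csv):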
user_id,item_id,behavior_type,user_geohash,item_category,time
10001082,285259775,1,97lk14c,4076,2014-12-08 18
10001082,4368907,1,,5503,2014-12-12 12
10001082,4368907,1,,5503,2014-12-12 12
10001082,53616768,1,,9762,2014-12-02 15
10001082,151466952,1,,5232,2014-12-12 11
10001082,53616768,4,,9762,2014-12-02 15
10001082,290088061,1,,5503,2014-12-12 12
10001082,298397524,1,,10894,2014-12-12 12
10001082,32104252,1,,6513,2014-12-12 12
10001082,323339743,1,,10894,2014-12-12 12
Item information (item.csv):
item_id,item_geohash,item_category
100002303,,3368
100003592,,7995
100006838,,12630
100008089,,7791
100012750,,9614
100014072,,1032
100014463,,9023
100019387,,3064
100023812,,6700
package entity;
public class Item implements Comparable<Item> {
private String itemId;
private String itemGeoHash;
private String itemCategory;
private double weight;
public String getItemId() {
return itemId;
}
public void setItemId(String itemId) {
this.itemId = itemId;
}
public String getItemGeoHash() {
return itemGeoHash;
}
public void setItemGeoHash(String itemGeoHash) {
this.itemGeoHash = itemGeoHash;
}
public String getItemCategory() {
return itemCategory;
}
public void setItemCategory(String itemCategory) {
this.itemCategory = itemCategory;
}
public double getWeight() {
return weight;
}
public void setWeight(double weight) {
this.weight = weight;
}
@Override
public String toString() {
return "Item [itemId=" + itemId + ", itemGeoHash=" + itemGeoHash
+ ", itemCategory=" + itemCategory + ", weight=" + weight + "]";
}
@Override
public int compareTo(Item o) {
return Double.compare(o.weight, this.weight); // descending by weight; an (int) cast would truncate differences below 1 to 0
}
}
package entity;
public class Score implements Comparable<Score> {
private String userId; // user identifier
private String itemId; // item identifier
private double score;
public String getUserId() {
return userId;
}
public void setUserId(String userId) {
this.userId = userId;
}
public String getItemId() {
return itemId;
}
public void setItemId(String itemId) {
this.itemId = itemId;
}
public double getScore() {
return score;
}
public void setScore(double score) {
this.score = score;
}
@Override
public String toString() {
return "Score [userId=" + userId + ", itemId=" + itemId + ", score="
+ score + "]";
}
@Override
public int compareTo(Score o) {
if ((this.score - o.score) < 0) {
return 1;
}else if ((this.score - o.score) > 0) {
return -1;
}else {
return 0;
}
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((itemId == null) ? 0 : itemId.hashCode());
long temp;
temp = Double.doubleToLongBits(score);
result = prime * result + (int) (temp ^ (temp >>> 32));
result = prime * result + ((userId == null) ? 0 : userId.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Score other = (Score) obj;
if (itemId == null) {
if (other.itemId != null)
return false;
} else if (!itemId.equals(other.itemId))
return false;
if (Double.doubleToLongBits(score) != Double
.doubleToLongBits(other.score))
return false;
if (userId == null) {
if (other.userId != null)
return false;
} else if (!userId.equals(other.userId))
return false;
return true;
}
}
package entity;
public class User implements Comparable<User> {
private String userId; // user identifier
private String itemId; // item identifier
private double behaviorType; // behavior type: browse, favorite, add to cart, purchase = 1, 2, 3, 4
private String userGeoHash; // spatial identifier of the user's location (may be empty)
private String itemCategory;// item category identifier
private String time; // behavior time
private int count;
private double weight; // weight
public String getUserId() {
return userId;
}
public void setUserId(String userId) {
this.userId = userId;
}
public String getItemId() {
return itemId;
}
public void setItemId(String itemId) {
this.itemId = itemId;
}
public double getBehaviorType() {
return behaviorType;
}
public void setBehaviorType(double behaviorType) {
this.behaviorType = behaviorType;
}
public String getUserGeoHash() {
return userGeoHash;
}
public void setUserGeoHash(String userGeoHash) {
this.userGeoHash = userGeoHash;
}
public String getItemCategory() {
return itemCategory;
}
public void setItemCategory(String itemCategory) {
this.itemCategory = itemCategory;
}
public String getTime() {
return time;
}
public void setTime(String time) {
this.time = time;
}
@Override
public String toString() {
return "User [userId=" + userId + ", itemId=" + itemId
+ ", behaviorType=" + behaviorType + ", count=" + count + "]";
}
public int getCount() {
return count;
}
public void setCount(int count) {
this.count = count;
}
public double getWeight() {
return weight;
}
public void setWeight(double weight) {
this.weight = weight;
}
@Override
public int compareTo(User o) {
return Double.compare(o.weight, this.weight); // descending by weight, without truncating small differences
}
}
2. Utility Classes
File-handling utility:
package util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import entity.Item;
import entity.Score;
import entity.User;
public class FileTool {
public static FileReader fr=null;
public static BufferedReader br=null;
public static String line=null;
public static FileOutputStream fos1 = null,fos2 = null,fos3 = null;
public static PrintStream ps1 = null,ps2 = null,ps3 = null;
public static int count = 0;
/**
* Initialize the writer (single output stream)
* */
public static void initWriter1(String writePath) {
try {
fos1 = new FileOutputStream(writePath);
ps1 = new PrintStream(fos1);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
/**
* Close the shared reader
* */
public static void closeReader() {
try {
br.close();
fr.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Close the writer (single output stream)
* */
public static void closeWriter1() {
try {
ps1.close();
fos1.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Initialize the writers (two output streams)
* */
public static void initWriter2(String writePath1,String writePath2) {
try {
fos1 = new FileOutputStream(writePath1);
ps1 = new PrintStream(fos1);
fos2 = new FileOutputStream(writePath2);
ps2 = new PrintStream(fos2);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
/**
* Close the writers (two output streams)
* */
public static void closeWriter2() {
try {
ps1.close();
fos1.close();
ps2.close();
fos2.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Initialize the writers (three output streams)
* */
public static void initWriter3(String writePath1,String writePath2,String writePath3) {
try {
fos1 = new FileOutputStream(writePath1);
ps1 = new PrintStream(fos1);
fos2 = new FileOutputStream(writePath2);
ps2 = new PrintStream(fos2);
fos3 = new FileOutputStream(writePath3);
ps3 = new PrintStream(fos3);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
/**
* Close the writers (three output streams)
* */
public static void closeWriter3() {
try {
ps1.close();
fos1.close();
ps2.close();
fos2.close();
ps3.close();
fos3.close();
} catch (IOException e) {
e.printStackTrace();
}
}
public static List readFileOne(String path, boolean isTitle, String token, String pattern, String process) throws Exception {
List<Object> ret = new ArrayList<Object>();
fr = new FileReader(path);
br = new BufferedReader(fr);
int count = 0,i = 0;
if (isTitle) {
line = br.readLine();
count++;
}
while((line = br.readLine()) != null){
String[] strArr = line.split(token);
switch (pattern) {
case "item":
ret.add(ParseTool.parseItem(strArr));
break;
case "user":
ret.add(ParseTool.parseUser(strArr, process));
break;
default:
ret.add(line);
break;
}
count++;
if (count/100000 == 1) {
i++;
System.out.println(100000*i);
count = 0;
}
}
closeReader();
return ret;
}
public static void makeSampleData(String inputPath,boolean isTitle,String outputPath,int threshold) throws Exception {
fr = new FileReader(inputPath);
br = new BufferedReader(fr);
initWriter1(outputPath);
if (isTitle) {
line = br.readLine();
}
int count = 0;
while((line = br.readLine()) != null){
ps1.println(line);
count++;
if (count == threshold) {
break;
}
}
closeReader();
closeWriter1(); // flush and release the sample output file
}
public static List<String> traverseFolder(String dir) {
File file = new File(dir);
String[] fileList = null;
if (file.exists()) {
fileList = file.list();
}
List<String> list = new ArrayList<String>();
if (fileList == null) {
return list; // directory missing or unreadable: nothing to traverse
}
for(String path : fileList){
list.add(path);
}
return list;
}
public static Map<String, List<Score>> loadScoreMap(String path, boolean isTitle, String token, String type) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
Map<String, List<Score>> scoreMap = new HashMap<String, List<Score>>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
Score score = null;
switch (type) {
case "userCF":
score = ParseTool.parseScoreByItemCF(arr);
break;
case "itemCF":
score = ParseTool.parseScoreByUserCF(arr);
break;
default:
break;
}
List<Score> temp = new ArrayList<Score>();
if (scoreMap.containsKey(score.getItemId())) {
temp = scoreMap.get(score.getItemId());
}
temp.add(score);
scoreMap.put(score.getItemId(), temp);
}
closeReader();
return scoreMap;
}
public static Map<String, List<String>> loadPredictData(String path, boolean isTitle, String token, int n) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
Map<String, List<String>> map = new HashMap<String, List<String>>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
String userId = arr[0];
String itemId = arr[1];
List<String> temp = new ArrayList<String>();
if (map.containsKey(userId)) {
temp = map.get(userId);
}
if (!temp.contains(itemId) && temp.size() < n) { // keep at most n items per user
temp.add(itemId);
map.put(userId, temp);
count++;
}
}
closeReader();
return map;
}
public static Map<String, List<String>> loadTestData(Map<String, List<String>> predictMap, String dir, boolean isTitle, String token) throws Exception {
List<String> fileList = traverseFolder(dir);
Set<String> predictKeySet = predictMap.keySet();
Map<String, List<String>> testMap = new HashMap<String, List<String>>();
for(String predictKey : predictKeySet){
if (fileList.contains(predictKey)) {
List<String> itemList = getIdList(dir + predictKey, isTitle, token);
testMap.put(predictKey, itemList);
}
}
return testMap;
}
public static List<String> getIdList(String path, boolean isTitle, String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
List<String> list = new ArrayList<String>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
if (!list.contains(arr[0].trim())) {
list.add(arr[0].trim());
}
count++;
}
closeReader();
return list;
}
public static Map<String, Double> loadUser_ItemData(String path,boolean isTitle,String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
Map<String, Double> map = new HashMap<String, Double>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
String itemId = arr[1];
double score = Double.valueOf(arr[2]);
if(map.containsKey(itemId)){
double temp = map.get(itemId);
if (temp > score) {
score = temp;
}
}
map.put(itemId, score);
}
closeReader();
return map;
}
public static void makeSampleData(String path, boolean isTitle,String token, List<String> userList, List<String> itemList) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
while((line = br.readLine()) != null){
String[] arr = line.split(token);
String userId = arr[0];
String itemId = arr[1];
if (userList.contains(userId) && itemList.contains(itemId)) {
FileTool.ps1.println(line);
}
}
closeReader();
}
}
Parsing utility:
package util;
import entity.Item;
import entity.Score;
import entity.User;
public class ParseTool {
public static boolean isNumber(String str) {
int i,n;
n = str.length();
for(i = 0;i < n;i++){
if (!Character.isDigit(str.charAt(i))) {
return false;
}
}
return true;
}
public static Item parseItem(String[] contents) {
Item item = new Item();
if (contents[0] != null && !contents[0].isEmpty()) {
item.setItemId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
item.setItemGeoHash(contents[1].trim());
}
if (contents[2] != null && !contents[2].isEmpty()) {
item.setItemCategory(contents[2].trim());
}
return item;
}
public static User parseUser(String[] contents, String type) {
User user = new User();
int n = contents.length;
if (contents[0] != null && !contents[0].isEmpty()) {
user.setUserId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
user.setItemId(contents[1].trim());
}
if ("mapUser".equals(type)) {
// 1. Enable this block when calling SpliteFileAndMakeScoreTable; comment it out otherwise
if (contents[2] != null && !contents[2].isEmpty()) {
user.setBehaviorType(Double.valueOf(contents[2].trim()));
}
/*
// sample2
if (contents[3] != null && !contents[3].isEmpty()) {
user.setUserGeoHash(contents[3].trim());
}
if (contents[4] != null && !contents[4].isEmpty()) {
user.setItemCategory(contents[4].trim());
}
if (contents[5] != null && !contents[5].isEmpty()) {
user.setTime(contents[5].trim());
}
*/
// movielens
if (contents[3] != null && !contents[3].isEmpty()) {
user.setTime(contents[3].trim());
}
}
if ("reduceUser".equals(type)) {
// 2. Enable this block when calling ReducefileTest; comment it out otherwise
if (contents[2] != null && !contents[2].isEmpty()) {
user.setBehaviorType(Double.parseDouble(contents[2].trim()));
}
if (contents[n-1] != null && !contents[n-1].isEmpty()) {
user.setCount(Integer.valueOf(contents[n-1].trim()));
}
}
if ("predict".equals(type)) {
// 3. Enable this block when calling PredictTest; comment it out otherwise
if (contents[n-1] != null && !contents[n-1].isEmpty()) {
user.setWeight(Double.valueOf(contents[n-1].trim()));
}
}
return user;
}
public static Score parseScoreByUserCF(String[] contents) {
Score score = new Score();
if (contents[0] != null && !contents[0].isEmpty()) {
score.setUserId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
score.setItemId(contents[1].trim());
}
if (contents[2] != null && !contents[2].isEmpty() && !"Infinity".equals(contents[2])) {
double score_temp = Double.parseDouble(contents[2].trim());
if (score_temp < 0.0001) {
score_temp = 0;
}
score.setScore(score_temp);
}
return score;
}
public static Score parseScoreByItemCF(String[] contents) {
Score score = new Score();
if (contents[0] != null && !contents[0].isEmpty()) {
score.setItemId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
score.setUserId(contents[1].trim());
}
if (contents[2] != null && !contents[2].isEmpty() && !"Infinity".equals(contents[2])) {
double score_temp = Double.parseDouble(contents[2].trim());
if (score_temp < 0.0001) {
score_temp = 0;
}
score.setScore(score_temp);
}
return score;
}
}
3. Data Processing Module
User-based CF data processing module:
package service;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import util.FileTool;
import entity.Item;
import entity.Score;
import entity.User;
public class DataProcessByUserCF {
public static final double[] w = {0,10,20,30};
public static void output(Map<String, Map<String, List<User>>> userMap,String outputPath) {
for(Entry<String, Map<String, List<User>>> entry : userMap.entrySet()){
FileTool.initWriter1(outputPath + entry.getKey());
Map<String, List<User>> temp = entry.getValue();
for(Entry<String, List<User>> tempEntry : temp.entrySet()){
List<User> users = tempEntry.getValue();
int count = users.size();
for(User user : users){
FileTool.ps1.print(user.getUserId() + "\t");
FileTool.ps1.print(user.getItemId() + "\t");
FileTool.ps1.print(user.getBehaviorType() + "\t");
//FileTool.ps1.print(user.getUserGeoHash() + "\t");
//FileTool.ps1.print(user.getItemCategory() + "\t");
//FileTool.ps1.print(user.getTime() + "\t");
FileTool.ps1.print(count + "\n");
}
}
FileTool.closeWriter1(); // close each per-user file before opening the next one
}
}
public static void output(Map<String, Map<String, Double>> scoreTable, String outputPath, Set<String> userSet, Set<String> itemSet, String token) {
FileTool.initWriter1(outputPath);
for(String itemId: itemSet){
FileTool.ps1.print(token + itemId);
}
FileTool.ps1.println();
for(String userId : userSet){
FileTool.ps1.print(userId + token);
Map<String, Double> itemMap = scoreTable.get(userId);
for(String itemId: itemSet){
if(itemMap.containsKey(itemId)){
FileTool.ps1.print(itemMap.get(itemId));
}else {
//FileTool.ps1.print(0);
}
FileTool.ps1.print(token);
}
FileTool.ps1.print("\n");
}
FileTool.closeWriter1();
}
public static void outputUser(List<User> userList) {
for(User user : userList){
FileTool.ps1.println(user.getUserId() + "\t" + user.getItemId() + "\t" + user.getWeight());
}
}
public static void outputScore(List<Score> scoreList) {
for(Score score : scoreList){
FileTool.ps1.println(score.getUserId() + "\t" + score.getItemId() + "\t" + score.getScore());
}
}
public static void outputRecommendList(Map<String, Set<String>> map) {
for(Entry<String, Set<String>> entry : map.entrySet()){
String userId = entry.getKey();
Set<String> itemSet = entry.getValue();
for(String itemId : itemSet){
FileTool.ps1.println(userId + "," + itemId);
}
}
}
public static void output(Map<String, Set<String>> map) {
for(Entry<String, Set<String>> entry : map.entrySet()){
String userId = entry.getKey();
Set<String> set = entry.getValue();
for(String itemId : set){
FileTool.ps1.println(userId + "\t" + itemId);
}
}
}
public static Map<String, Map<String, List<User>>> mapByUser(List<User> userList,Set<String> userSet,Set<String> itemSet) {
Map<String, Map<String, List<User>>> userMap = new HashMap<>();
for(User user: userList){
// group the records by user, then by item
Map<String, List<User>> tempMap = userMap.computeIfAbsent(user.getUserId(), k -> new HashMap<String, List<User>>());
tempMap.computeIfAbsent(user.getItemId(), k -> new ArrayList<User>()).add(user);
userSet.add(user.getUserId());
itemSet.add(user.getItemId());
}
return userMap;
}
public static Map<String, Map<String, Double>> makeScoreTable(Map<String, Map<String, List<User>>> userMap) {
Map<String, Map<String, Double>> scoreTable = new HashMap<String, Map<String,Double>>();
for(Entry<String, Map<String, List<User>>> userEntry : userMap.entrySet()){
Map<String, List<User>> itemMap = userEntry.getValue();
String userId = userEntry.getKey();
Map<String, Double> itemScoreMap = new HashMap<String, Double>();
for(Entry<String, List<User>> itemEntry : itemMap.entrySet()){
String itemId = itemEntry.getKey();
List<User> users = itemEntry.getValue();
double weight = 0.0;
double maxType = 0;
for(User user : users){
if (user.getBehaviorType() > maxType) {
maxType = user.getBehaviorType();
}
}
int count = users.size();
if (maxType != 0) {
//weight += w[(int)maxType-1]; //sample2
weight += (maxType-1) * 10; //movielens
}
weight += count;
itemScoreMap.put(itemId, weight);
}
scoreTable.put(userId, itemScoreMap);
}
return scoreTable;
}
public static double calculateWeight(double behaviorType, int count) {
//double weight = w[(int)(behaviorType)-1] + count; //sample2
double weight = (behaviorType - 1) * 10 + count; //movielens
return weight;
}
public static List<User> reduceUserByItem(List<User> userList) {
List<User> list = new ArrayList<User>();
Map<String, User> userMap = new LinkedHashMap<String, User>();
for(User user : userList){
List<String> itemList = new ArrayList<String>();
String itemId = user.getItemId();
if (!userMap.containsKey(itemId)) {
double weight = calculateWeight(user.getBehaviorType(), user.getCount());
user.setWeight(weight);
userMap.put(itemId, user);
list.add(user);
}else {
User temp = userMap.get(itemId);
double weight = calculateWeight(user.getBehaviorType(), user.getCount());
if (temp.getWeight() < weight) {
user.setWeight(weight);
userMap.put(itemId, user);
list.remove(temp);
list.add(user);
}
}
}
userMap.clear();
return list;
}
public static void sortScoreMap(Map<String, List<Score>> scoreMap) {
Set<String> userSet = scoreMap.keySet();
for(String userId : userSet){
List<Score> temp = scoreMap.get(userId);
Collections.sort(temp);
scoreMap.put(userId, temp);
}
}
public static Map<String, Set<String>> predict(Map<String, List<Score>> scoreMap, List<String> fileNameList, String userDir, int topNUser, double score_threshold, int topNItem, double weight_threshold) throws Exception {
Map<String, Set<String>> recommendList = new HashMap<String, Set<String>>();
for(Entry<String, List<Score>> entry : scoreMap.entrySet()){
String userId1 = entry.getKey();
List<Score> list = entry.getValue();
int countUser = 0;
Set<String> predictItemSet = new LinkedHashSet<String>();
for(Score score : list){
if (score.getScore() <= score_threshold) {
break;
}
String userId2 = score.getUserId();
if(fileNameList.contains(userId2)){
List<User> userList = FileTool.readFileOne(userDir + userId2, false, "\t", "user", "predict");
int countItem = 0;
for(User user : userList){
if (user.getWeight() <= weight_threshold) {
continue;
}
predictItemSet.add(user.getItemId());
countItem++;
if (countItem == topNItem) {
break;
}
}
countUser++;
}
if (countUser == topNUser) {
break;
}
}
recommendList.put(userId1, predictItemSet);
}
return recommendList;
}
public static void prediction(Map<String, List<String>> predictMap,int predictN, Map<String, List<String>> referenceMap, int refN) {
for(int i = 1;i <= 10;i++){
int count = 0;
predictN = 0;
for(Entry<String, List<String>> predictEntity : predictMap.entrySet()){
String userId = predictEntity.getKey();
if (referenceMap.containsKey(userId)) {
List<String> predictList = predictEntity.getValue();
int j = 0;
for(String itemId : predictList){
predictN++;
if (referenceMap.get(userId).contains(itemId)) {
count++;
}
j++;
if (j == i) {
break;
}
}
}
}
double precision = predictN == 0 ? 0 : (1.0 * count / predictN) * 100;
double recall = refN == 0 ? 0 : (1.0 * count / refN) * 100;
if (recall > 100) {
recall = 100;
}
// guard against 0/0, which would print NaN
double f1 = (precision + recall) == 0 ? 0 : (2 * precision * recall)/(precision + recall);
//System.out.println(predictN);
//System.out.println(refN);
System.out.println("Top-N per user:" + i + ", hits:" + count + ", total recommended:" + predictN + ", actual purchases:" + refN + ", precision:" + count + "/" + predictN + "=" + precision + "%, recall:" + count + "/" + refN + "=" + recall + "%, f1:" + f1 + "%");
}
}
}
Item-based CF data processing module:
package service;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import util.FileTool;
import entity.Item;
import entity.Score;
import entity.User;
public class DataProcessByItemCF {
public static final double[] w = {0,10,20,30};
public static void output(Map<String, Map<String, List<User>>> userMap,String outputPath) {
for(Entry<String, Map<String, List<User>>> entry : userMap.entrySet()){
FileTool.initWriter1(outputPath + entry.getKey());
Map<String, List<User>> temp = entry.getValue();
for(Entry<String, List<User>> tempEntry : temp.entrySet()){
List<User> users = tempEntry.getValue();
int count = users.size();
for(User user : users){
FileTool.ps1.print(user.getUserId() + "\t");
FileTool.ps1.print(user.getItemId() + "\t");
FileTool.ps1.print(user.getBehaviorType() + "\t");
//FileTool.ps1.print(user.getUserGeoHash() + "\t");
//FileTool.ps1.print(user.getItemCategory() + "\t");
//FileTool.ps1.print(user.getTime() + "\t");
FileTool.ps1.print(count + "\n");
}
}
FileTool.closeWriter1();
}
}
public static void output(Map<String, Map<String, Double>> scoreTable, String outputPath, Set<String> userSet, Set<String> itemSet, String token) {
FileTool.initWriter1(outputPath);
for(String userId: userSet){
FileTool.ps1.print(token + userId);
}
FileTool.ps1.println();
for(String itemId : itemSet){
FileTool.ps1.print(itemId + token);
Map<String, Double> userMap = scoreTable.get(itemId);
for(String userId: userSet){
if(userMap.containsKey(userId)){
FileTool.ps1.print(userMap.get(userId));
}else {
//FileTool.ps1.print(0);
}
FileTool.ps1.print(token);
}
FileTool.ps1.print("\n");
}
FileTool.closeWriter1();
}
public static void outputUser(List<User> userList) {
for(User user : userList){
FileTool.ps1.println(user.getUserId() + "\t" + user.getItemId() + "\t" + user.getWeight());
}
}
public static void outputScore(List<Score> scoreList) {
for(Score score : scoreList){
FileTool.ps1.println(score.getUserId() + "\t" + score.getItemId() + "\t" + score.getScore());
}
}
public static void outputRecommendList(Map<String, Set<String>> map) {
for(Entry<String, Set<String>> entry : map.entrySet()){
String userId = entry.getKey();
Set<String> itemSet = entry.getValue();
for(String itemId : itemSet){
FileTool.ps1.println(userId + "," + itemId);
}
}
}
public static void output(Map<String, Set<String>> map) {
for(Entry<String, Set<String>> entry : map.entrySet()){
String userId = entry.getKey();
Set<String> set = entry.getValue();
for(String itemId : set){
FileTool.ps1.println(userId + "\t" + itemId);
}
}
}
public static Map<String, Map<String, List<User>>> mapByUser(List<User> userList, Set<String> userSet, Set<String> itemSet) {
Map<String, Map<String, List<User>>> userMap = new HashMap<>();
for(User user: userList){
// group the records by user, then by item
Map<String, List<User>> tempMap = userMap.computeIfAbsent(user.getUserId(), k -> new HashMap<String, List<User>>());
tempMap.computeIfAbsent(user.getItemId(), k -> new ArrayList<User>()).add(user);
userSet.add(user.getUserId());
itemSet.add(user.getItemId());
}
return userMap;
}
public static Map<String, Map<String, Double>> makeScoreTable(Map<String, Map<String, List<User>>> userMap) {
Map<String, Map<String, Double>> scoreTable = new LinkedHashMap<>();
for(Entry<String, Map<String, List<User>>> userEntry : userMap.entrySet()){
Map<String, List<User>> itemMap = userEntry.getValue();
String userId = userEntry.getKey();
Map<String, Double> userScoreMap;
for(Entry<String, List<User>> itemEntry : itemMap.entrySet()){
String itemId = itemEntry.getKey();
List<User> users = itemEntry.getValue();
double weight = 0.0;
double maxType = 0;
for(User user : users){
if (user.getBehaviorType() > maxType) {
maxType = user.getBehaviorType();
}
}
int count = users.size();
if (maxType != 0) {
//weight += w[(int)maxType-1]; //sample2
weight += (maxType-1) * 10; //movielens
}
weight += count;
if (scoreTable.containsKey(itemId)) {
userScoreMap = scoreTable.get(itemId);
}else {
userScoreMap = new LinkedHashMap<String, Double>();
}
userScoreMap.put(userId, weight);
scoreTable.put(itemId, userScoreMap);
}
}
return scoreTable;
}
public static double calculateWeight(double behaviorType, int count) {
//double weight = w[(int)(behaviorType)-1] + count; //sample2
double weight = (behaviorType - 1) * 10 + count; //movielens
return weight;
}
public static List<User> reduceUserByItem(List<User> userList) {
List<User> list = new ArrayList<User>();
Map<String, User> userMap = new LinkedHashMap<String, User>();
for(User user : userList){
List<String> itemList = new ArrayList<String>();
String itemId = user.getItemId();
if (!userMap.containsKey(itemId)) {
double weight = calculateWeight(user.getBehaviorType(), user.getCount());
user.setWeight(weight);
userMap.put(itemId, user);
list.add(user);
}else {
User temp = userMap.get(itemId);
double weight = calculateWeight(user.getBehaviorType(), user.getCount());
if (temp.getWeight() < weight) {
user.setWeight(weight);
userMap.put(itemId, user);
list.remove(temp);
list.add(user);
}
}
}
userMap.clear();
return list;
}
public static void sortScoreMap(Map<String, List<Score>> scoreMap) {
Set<String> userSet = scoreMap.keySet();
for(String userId : userSet){
List<Score> temp = scoreMap.get(userId);
Collections.sort(temp);
scoreMap.put(userId, temp);
}
}
public static Map<String, Set<String>> predict(Map<String, List<Score>> scoreMap, List<String> fileNameList, String userDir, int topNUser, double score_threshold, int topNItem, double weight_threshold) throws Exception {
Map<String, Set<String>> recommendList = new HashMap<String, Set<String>>();
for(String userFileName : fileNameList){
//System.out.println("userFileName:"+userFileName);
List<User> userList = FileTool.readFileOne(userDir + userFileName, false, "\t", "user", "predict");
//System.out.println("userList:"+userList);
for(User user : userList){
String userId = user.getUserId();
if (user.getWeight() <= weight_threshold) {
continue;
}
String itemId = user.getItemId();
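List<Score> itemSimilarList = scoreMap.get(itemId);
//System.out.println("itemSimilarList:"+itemSimilarList);
if (itemSimilarList == null) {
continue; // skip items that have no row in the similarity table (avoids a NullPointerException)
}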
Set<String> predictItemSet;
if (recommendList.containsKey(userId)) {
predictItemSet = recommendList.get(userId);
}else {
predictItemSet = new LinkedHashSet<String>();
}
int countItem = 0;
for(Score similarItem : itemSimilarList){
predictItemSet.add(similarItem.getUserId());
if (countItem == topNItem) {
break;
}
countItem++;
}
recommendList.put(userId, predictItemSet);
}
}
return recommendList;
}
public static void prediction(Map<String, List<String>> predictMap,int predictN, Map<String, List<String>> referenceMap, int refN) {
int count = 0;
for(Entry<String, List<String>> predictEntity : predictMap.entrySet()){
String userId = predictEntity.getKey();
if (referenceMap.containsKey(userId)) {
List<String> predictList = predictEntity.getValue();
for(String itemId : predictList){
if (referenceMap.get(userId).contains(itemId)) {
count++;
}
}
}
}
double precision = (1.0 * count / predictN) * 100;
double recall = (1.0 * count / refN) * 100;
// guard against 0/0 when there are no hits at all
double f1 = (precision + recall) == 0 ? 0.0 : (2 * precision * recall)/(precision + recall);
System.out.println("precision="+precision+",recall="+recall+",f1="+f1);
}
}
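The tail of the score-table builder and `calculateWeight` above turn implicit feedback into a numeric rating: weight = (behaviorType - 1) * 10 + count, where behaviorType 1 to 4 means browse, favorite, add-to-cart and purchase. The following is a minimal sketch of that formula applied to raw log lines. It assumes `count` is the number of log records for a (user, item) pair and that the behavior type used is the highest one observed for the pair (as the name `maxType` suggests); the class and the `LogEntry` type are hypothetical stand-ins, not the project's `User` entity.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch (not the original makeScoreTable): builds user -> (item -> weight)
// with weight = (maxBehaviorType - 1) * 10 + number of log records for the pair.
class BehaviorWeightSketch {

    // One raw log line; an assumed stand-in for the original User entity.
    static final class LogEntry {
        final String userId;
        final String itemId;
        final int behaviorType; // 1 = browse, 2 = favorite, 3 = add to cart, 4 = purchase

        LogEntry(String userId, String itemId, int behaviorType) {
            this.userId = userId;
            this.itemId = itemId;
            this.behaviorType = behaviorType;
        }
    }

    static Map<String, Map<String, Double>> buildScoreTable(List<LogEntry> log) {
        // user -> item -> {max behavior type seen, number of records}
        Map<String, Map<String, int[]>> stats = new HashMap<>();
        for (LogEntry e : log) {
            int[] s = stats
                    .computeIfAbsent(e.userId, u -> new HashMap<>())
                    .computeIfAbsent(e.itemId, i -> new int[] {0, 0});
            s[0] = Math.max(s[0], e.behaviorType);
            s[1]++;
        }
        // Collapse the statistics into one weight per (user, item) pair.
        Map<String, Map<String, Double>> table = new HashMap<>();
        for (Map.Entry<String, Map<String, int[]>> user : stats.entrySet()) {
            Map<String, Double> row = new HashMap<>();
            for (Map.Entry<String, int[]> item : user.getValue().entrySet()) {
                int[] s = item.getValue();
                row.put(item.getKey(), (s[0] - 1) * 10.0 + s[1]);
            }
            table.put(user.getKey(), row);
        }
        return table;
    }
}
```

Because each behavior level adds 10, the highest behavior type usually dominates: a single purchase (weight 31) still outweighs an item that was only browsed, as long as it was browsed no more than 30 times.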
4. Computation modules
User-based collaborative filtering: similarity computation module
package service;
import java.text.DecimalFormat;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import util.FileTool;
public class CalculateSimilarityByUserCF {
public static DecimalFormat df = new DecimalFormat("#.0000");
// Known issue with this metric: comparing (a) 1,0,1 with 1,20,1 and (b) 0,1,0 with 0,0,0, pair (b) comes out closer than pair (a)
public static double EuclidDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double sum = 0;
for (String itemId : itemSet) {
double score1 = 0.0;
double score2 = 0.0;
if (userMap1.containsKey(itemId) && userMap2.containsKey(itemId)) {
score1 = userMap1.get(itemId);
score2 = userMap2.get(itemId);
} else if (userMap1.containsKey(itemId)) {
score1 = userMap1.get(itemId);
} else if (userMap2.containsKey(itemId)) {
score2 = userMap2.get(itemId);
}
double temp = Math.pow((score1 - score2), 2);
sum += temp;
}
sum = Math.sqrt(sum);
return sum;
}
public static double CosineDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double dist = 0;
double numerator = 0; // numerator: dot product of the two rating vectors
double denominator1 = 0; // denominator: squared norm of vector 1
double denominator2 = 0; // denominator: squared norm of vector 2
for (String itemId : itemSet) {
double score1 = 0.0;
double score2 = 0.0;
if (userMap1.containsKey(itemId) && userMap2.containsKey(itemId)) {
score1 = userMap1.get(itemId);
score2 = userMap2.get(itemId);
// only items rated by both users contribute to the dot product (a sum, not a product)
numerator += score1 * score2;
} else if (userMap1.containsKey(itemId)) {
score1 = userMap1.get(itemId);
} else if (userMap2.containsKey(itemId)) {
score2 = userMap2.get(itemId);
}
denominator1 += Math.pow(score1, 2);
denominator2 += Math.pow(score2, 2);
}
if (denominator1 == 0 || denominator2 == 0) {
return 0.0; // an all-zero vector has no direction; avoid 0/0
}
dist = numerator / (Math.sqrt(denominator1) * Math.sqrt(denominator2));
return dist;
}
public static double execute(Map<String, Double> userMap1,Map<String, Double> userMap2,Set<String> userSet,Set<String> itemSet) {
/*
double dist = EuclidDist(userMap1, userMap2, userSet, itemSet);
double userScore = 1.0 / (1.0 + dist);
*/
double userScore = CosineDist(userMap1, userMap2, userSet, itemSet);
return userScore;
}
public static void execute(String userId,Map<String, Map<String, Double>> scoreTable,
Set<String> userSet, Set<String> itemSet) {
for (Entry<String, Map<String, Double>> userEntry : scoreTable.entrySet()) {
String userId2 = userEntry.getKey();
Map<String, Double> userMap2 = userEntry.getValue();
//double dist = EuclidDist(scoreTable.get(userId), userMap2, userSet, itemSet);
//double userScore = 1.0 / (1.0 + dist);
double userScore = CosineDist(scoreTable.get(userId), userMap2, userSet, itemSet);
FileTool.ps1.println(userId + "," + userId2 + "," + df.format(userScore));
}
}
public static void execute(Map<String, Map<String, Double>> scoreTable,
Set<String> userSet, Set<String> itemSet) {
int count = 0;
for (Entry<String, Map<String, Double>> userEntry1 : scoreTable.entrySet()) {
long startTime = System.currentTimeMillis();
String userId = userEntry1.getKey();
execute(userId, scoreTable, userSet, itemSet);
count++;
long endTime = System.currentTimeMillis();
long dur = endTime - startTime;
System.out.println("user count:" + count + ",dur time:" + dur + "," + (dur/1000) + "s," + (dur/1000/60) + "m");
}
}
}
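Cosine similarity divides the dot product of two rating vectors by the product of their lengths: cos(u, v) = Σ_i r(u,i)·r(v,i) / (‖r_u‖·‖r_v‖). Since an absent rating counts as 0, only keys present in both maps contribute to the numerator, so the full pass over `itemSet` in `CosineDist` is not strictly needed. A minimal sketch under that assumption, working directly on sparse `Map<String, Double>` rows (the class name is hypothetical):

```java
import java.util.Map;

// Hypothetical sketch: cosine similarity over sparse rating maps (id -> rating).
// Absent keys are treated as rating 0, as in CosineDist.
class SparseCosineSketch {

    static double cosine(Map<String, Double> a, Map<String, Double> b) {
        // Iterate over the smaller map: only shared keys add to the dot product.
        Map<String, Double> small = a.size() <= b.size() ? a : b;
        Map<String, Double> large = small == a ? b : a;
        double dot = 0.0;
        for (Map.Entry<String, Double> e : small.entrySet()) {
            Double other = large.get(e.getKey());
            if (other != null) {
                dot += e.getValue() * other;
            }
        }
        double normA = norm(a);
        double normB = norm(b);
        if (normA == 0.0 || normB == 0.0) {
            return 0.0; // undefined for an all-zero vector; report 0
        }
        return dot / (normA * normB);
    }

    private static double norm(Map<String, Double> v) {
        double sum = 0.0;
        for (double x : v.values()) {
            sum += x * x;
        }
        return Math.sqrt(sum);
    }
}
```

Iterating only over stored entries costs O(|r_u| + |r_v|) per pair instead of O(|itemSet|), which may matter here because `execute` compares every user with every other user; the norms could also be precomputed once per row.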
Item-based collaborative filtering: similarity computation module
package service;
import java.text.DecimalFormat;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import util.FileTool;
public class CalculateSimilarityByItemCF {
public static DecimalFormat df = new DecimalFormat("#.0000");
// Known issue with this metric: comparing (a) 1,0,1 with 1,20,1 and (b) 0,1,0 with 0,0,0, pair (b) comes out closer than pair (a)
public static double EuclidDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double sum = 0;
for (String userId : userSet) {
double score1 = 0.0;
double score2 = 0.0;
if (userMap1.containsKey(userId) && userMap2.containsKey(userId)) {
score1 = userMap1.get(userId);
score2 = userMap2.get(userId);
} else if (userMap1.containsKey(userId)) {
score1 = userMap1.get(userId);
} else if (userMap2.containsKey(userId)) {
score2 = userMap2.get(userId);
}
double temp = Math.pow((score1 - score2), 2);
sum += temp;
}
sum = Math.sqrt(sum);
return sum;
}
public static double CosineDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double dist = 0;
double numerator = 0; // numerator: dot product of the two rating vectors
double denominator1 = 0; // denominator: squared norm of vector 1
double denominator2 = 0; // denominator: squared norm of vector 2
for (String userId : userSet) {
double score1 = 0.0;
double score2 = 0.0;
if (userMap1.containsKey(userId) && userMap2.containsKey(userId)) {
score1 = userMap1.get(userId);
score2 = userMap2.get(userId);
// only users who interacted with both items contribute to the dot product (a sum, not a product)
numerator += score1 * score2;
} else if (userMap1.containsKey(userId)) {
score1 = userMap1.get(userId);
} else if (userMap2.containsKey(userId)) {
score2 = userMap2.get(userId);
}
denominator1 += Math.pow(score1, 2);
denominator2 += Math.pow(score2, 2);
}
if (denominator1 == 0 || denominator2 == 0) {
return 0.0; // an all-zero vector has no direction; avoid 0/0
}
dist = numerator / (Math.sqrt(denominator1) * Math.sqrt(denominator2));
return dist;
}
public static double SimpleDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double dist = 0;
double numerator = 0; // numerator: number of users shared by both items
double denominator1 = 0; // denominator: number of users of item 1
double denominator2 = 0; // denominator: number of users of item 2
for (String userId : userSet) {
if (userMap1.containsKey(userId) && userMap2.containsKey(userId)) {
numerator++;
denominator1++;
denominator2++;
} else if (userMap1.containsKey(userId)) {
denominator1++;
} else if (userMap2.containsKey(userId)) {
denominator2++;
}
}
if (denominator1 == 0 || denominator2 == 0) {
return 0.0; // avoid 0/0 when an item has no users
}
dist = ((1.0 * numerator) / (Math.sqrt(denominator1) * Math.sqrt(denominator2)));
return dist;
}
public static double execute(Map<String, Double> userMap1,Map<String, Double> userMap2,Set<String> userSet,Set<String> itemSet) {
/*
double dist = EuclidDist(userMap1, userMap2, userSet, itemSet);
double userScore = 1.0 / (1.0 + dist);
*/
double userScore = CosineDist(userMap1, userMap2, userSet, itemSet);
return userScore;
}
public static void execute(String itemId, Map<String, Map<String, Double>> scoreTable,
Set<String> userSet, Set<String> itemSet) {
for (Entry<String, Map<String, Double>> itemEntry : scoreTable.entrySet()) {
String itemId2 = itemEntry.getKey();
Map<String, Double> userMap2 = itemEntry.getValue();
//double dist = EuclidDist(scoreTable.get(itemId), userMap2, userSet, itemSet);
//double userScore = 1.0 / (1.0 + dist);
//double userScore = CosineDist(scoreTable.get(itemId), userMap2, userSet, itemSet);
double userScore = SimpleDist(scoreTable.get(itemId), userMap2, userSet, itemSet);
FileTool.ps1.println(itemId + "," + itemId2 + "," + df.format(userScore));
}
}
public static void execute(Map<String, Map<String, Double>> scoreTable,
Set<String> userSet, Set<String> itemSet) {
int count = 0;
for (Entry<String, Map<String, Double>> itemEntry : scoreTable.entrySet()) {
long startTime = System.currentTimeMillis();
String itemId = itemEntry.getKey();
execute(itemId, scoreTable, userSet, itemSet);
count++;
long endTime = System.currentTimeMillis();
long dur = endTime - startTime;
System.out.println("item count:" + count + ",dur time:" + dur + "," + (dur/1000) + "s," + (dur/1000/60) + "m");
}
}
}
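`SimpleDist` ignores the weights and treats each item as the set of users who interacted with it, which gives the binary cosine |N(i) ∩ N(j)| / sqrt(|N(i)|·|N(j)|). A minimal sketch on plain `Set<String>` inputs, assuming N(i) has already been extracted from the score table (the class name is hypothetical):

```java
import java.util.Set;

// Hypothetical sketch of the binary ("co-occurrence") cosine that SimpleDist computes:
// |N(i) ∩ N(j)| / sqrt(|N(i)| * |N(j)|), where N(x) is the set of users who touched item x.
class CooccurrenceCosineSketch {

    static double similarity(Set<String> usersOfItemI, Set<String> usersOfItemJ) {
        if (usersOfItemI.isEmpty() || usersOfItemJ.isEmpty()) {
            return 0.0; // avoid 0/0 when one item has no users
        }
        // Probe the larger set with members of the smaller one.
        Set<String> small = usersOfItemI.size() <= usersOfItemJ.size() ? usersOfItemI : usersOfItemJ;
        Set<String> large = small == usersOfItemI ? usersOfItemJ : usersOfItemI;
        int common = 0;
        for (String user : small) {
            if (large.contains(user)) {
                common++;
            }
        }
        return common / Math.sqrt((double) usersOfItemI.size() * usersOfItemJ.size());
    }
}
```

Dividing by the geometric mean of the two set sizes keeps very popular items from looking similar to everything merely because many users touched them.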
5. Scripts
Generate userSet and itemSet:
package script;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import entity.User;
import util.FileTool;
public class MakeSet {
public static void main(String[] args) throws Exception {
String inputDir = args[0];
String outputDir = args[1];
Set<String> userSet = new HashSet<String>();
Set<String> itemSet = new HashSet<String>();
List<String> pathList = FileTool.traverseFolder(inputDir);
for(String path : pathList){
String inputPath = inputDir + path;
List<User> list = FileTool.readFileOne(inputPath, false, "\t", "user", "reduceUser");
for(User user : list){
userSet.add(user.getUserId());
itemSet.add(user.getItemId());
}
}
FileTool.initWriter1(outputDir+"userSet");
for(String userId : userSet){
FileTool.ps1.println(userId);
}
FileTool.closeWriter1();
FileTool.initWriter1(outputDir+"itemSet");
for(String itemId : itemSet){
FileTool.ps1.println(itemId);
}
FileTool.closeWriter1();
}
}
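User-based collaborative filtering: map the input file, build the user-item rating matrix, and compute user-user similarities to produce the user-user score table: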
package script.runByUserCF;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import service.CalculateSimilarityByUserCF;
import service.DataProcessByUserCF;
import util.FileTool;
import entity.User;
public class SpliteFileAndMakeScoreTable {
public static void main(String[] args) throws Exception {
String inputDir = "data/fresh_comp_offline/sample2/";
String outputDir = "data/fresh_comp_offline/sample2/out/user/";
String userPath = inputDir + "train_user.csv";
String outputPath = "data/fresh_comp_offline/sample2/out/score.csv";
String scoreTablePath = "data/fresh_comp_offline/sample2/out/scoreTable.csv";
/*
String inputDir = "data/ml-1m/";
String outputDir = "data/ml-1m/out/user/";
String userPath = inputDir + "ratings.dat";
String outputPath = "data/ml-1m/out/score.csv";
*/
/*
String inputDir = args[0];
String outputDir = args[1];
String userPath = inputDir + args[2];
String outputPath = args[3];
String scoreTablePath = args[4];
*/
//FileTool.makeSampleData(userPath, true, outputPath, 10000);
//List<Object> itemList = FileTool.readFileOne(itemPath, true, ",", "item");
List<User> userList = FileTool.readFileOne(userPath, true, ",", "user", "mapUser"); //sample2
//List<User> userList = FileTool.readFileOne(userPath, true, "::", "user", "mapUser"); //movielens
Set<String> userSet = new HashSet<String>();
Set<String> itemSet = new HashSet<String>();
Map<String, Map<String, List<User>>> userMap = DataProcessByUserCF.mapByUser(userList,userSet,itemSet);
userList.clear();
DataProcessByUserCF.output(userMap, outputDir);
// Build the userToItem score table
Map<String, Map<String, Double>> scoreTable = DataProcessByUserCF.makeScoreTable(userMap);
DataProcessByUserCF.output(scoreTable, scoreTablePath , userSet, itemSet, ",");
userMap.clear();
FileTool.initWriter1(outputPath);
CalculateSimilarityByUserCF.execute(scoreTable, userSet, itemSet);
FileTool.closeWriter1();
}
}
User-based collaborative filtering: reduce the mapped files and sort the users:
package script.runByUserCF;
import java.util.Collections;
import java.util.List;
import service.DataProcessByUserCF;
import util.FileTool;
import entity.User;
public class ReduceFileTest {
public static void main(String[] args) throws Exception {
String inputDir = "data/fresh_comp_offline/sample2/out/user/";
String outputDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
/*
String inputDir = "data/ml-1m/out/user/";
String outputDir = "data/ml-1m/out/sorteduser/";
*/
/*
String inputDir = args[0];
String outputDir = args[1];
*/
List<String> pathList = FileTool.traverseFolder(inputDir);
for(String path : pathList){
List<User> userList = FileTool.readFileOne(inputDir+path, false, "\t", "user", "reduceUser");
List<User> list = DataProcessByUserCF.reduceUserByItem(userList);
userList.clear();
FileTool.initWriter1(outputDir + path);
Collections.sort(list);
DataProcessByUserCF.outputUser(list);
FileTool.closeWriter1();
list.clear();
}
}
}
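`reduceUserByItem` keeps, for every item, only the record with the largest weight, and the script then calls `Collections.sort`, whose order is defined by `User.compareTo` (not shown in this section). A minimal sketch of the same reduce step using `Map.merge`, assuming a descending-by-weight order is intended; the input is a hypothetical list of (itemId, weight) entries rather than `User` objects:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: per-item max-weight dedup followed by a descending sort.
class ReduceByItemSketch {

    static List<Map.Entry<String, Double>> reduceAndSort(List<Map.Entry<String, Double>> itemWeights) {
        Map<String, Double> best = new HashMap<>();
        for (Map.Entry<String, Double> e : itemWeights) {
            // merge keeps the larger weight when the same item is seen again
            best.merge(e.getKey(), e.getValue(), Math::max);
        }
        List<Map.Entry<String, Double>> result = new ArrayList<>(best.entrySet());
        // highest weight first (assumed ordering)
        result.sort(Map.Entry.<String, Double>comparingByValue(Comparator.reverseOrder()));
        return result;
    }
}
```

A map lookup replaces the `list.remove(temp)` call in the original loop, which has to scan the list each time a heavier record replaces a lighter one.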
User-based collaborative filtering: recommend items for each user and generate the prediction list:
package script.runByUserCF;
import java.util.List;
import java.util.Map;
import java.util.Set;
import service.DataProcessByUserCF;
import util.FileTool;
import entity.Score;
public class PredictTest {
public static void main(String[] args) throws Exception {
/*
String inputDir = "data/fresh_comp_offline/sample2/out/";
String outputDir = "data/fresh_comp_offline/sample2/out/";
String inputPath = inputDir + "score.csv";
String outputPath = outputDir + "predict";
String userDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
*/
String inputDir = "data/fresh_comp_offline/sample2/out/";
String outputDir = "data/fresh_comp_offline/sample2/out/";
String inputPath = inputDir + "score.csv";
String outputPath = outputDir + "predict";
String userDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
int topNUser = 5;
double score_threshold = 0;
int topNItem = 5;
double weight_threshold = 1.0;
/*
String inputDir = args[0];
String outputDir = args[1];
String inputPath = inputDir + args[2];
String outputPath = outputDir + args[3];
String userDir = args[4];
int topNUser = Integer.parseInt(args[5]);
double score_threshold = Double.parseDouble(args[6]);
int topNItem = Integer.parseInt(args[7]);
double weight_threshold = Double.parseDouble(args[8]);
*/
Map<String, List<Score>> scoreMap = FileTool.loadScoreMap(inputPath, false, ",", "userCF");
DataProcessByUserCF.sortScoreMap(scoreMap);
List<String> fileNameList = FileTool.traverseFolder(userDir);
// Recommend the top 5 items of each of the user's 5 most similar users
Map<String, Set<String>> predictMap = DataProcessByUserCF.predict(scoreMap, fileNameList, userDir, topNUser, score_threshold, topNItem, weight_threshold);
FileTool.initWriter1(outputPath);
DataProcessByUserCF.outputRecommendList(predictMap);
FileTool.closeWriter1();
scoreMap.clear();
}
}
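The comment in `PredictTest` describes the user-based strategy: take the `topNUser` most similar users and recommend each one's `topNItem` highest-weighted items. A minimal sketch of that strategy for a single user, assuming the user's similarity row and the score table are already in memory as maps (class and method names are hypothetical):

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch of "top topNItem items from each of the topNUser most similar users".
// similarities: neighbor id -> sim(userId, neighbor); scoreTable: user id -> (item id -> weight).
class UserCfRecommendSketch {

    static Set<String> recommend(String userId, Map<String, Double> similarities,
                                 Map<String, Map<String, Double>> scoreTable,
                                 int topNUser, int topNItem) {
        List<Map.Entry<String, Double>> neighbors = new ArrayList<>(similarities.entrySet());
        // The similarity table also pairs each user with itself; drop that row.
        neighbors.removeIf(e -> e.getKey().equals(userId));
        neighbors.sort(byValueDescending());

        Set<String> result = new LinkedHashSet<>(); // keeps insertion order, drops duplicates
        for (Map.Entry<String, Double> neighbor : neighbors.subList(0, Math.min(topNUser, neighbors.size()))) {
            List<Map.Entry<String, Double>> items =
                    new ArrayList<>(scoreTable.getOrDefault(neighbor.getKey(), Map.of()).entrySet());
            items.sort(byValueDescending());
            for (Map.Entry<String, Double> item : items.subList(0, Math.min(topNItem, items.size()))) {
                result.add(item.getKey());
            }
        }
        return result;
    }

    private static Comparator<Map.Entry<String, Double>> byValueDescending() {
        return Map.Entry.<String, Double>comparingByValue(Comparator.reverseOrder());
    }
}
```

The sketch does not exclude items the target user has already interacted with; in a next-day purchase task repeat purchases are plausible, so whether to filter them is left as a tuning choice.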
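Item-based collaborative filtering: map the input file, build the user-item rating matrix, and compute item-item similarities to produce the item-item score table: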
package script.runByItemCF;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import service.CalculateSimilarityByItemCF;
import service.DataProcessByItemCF;
import util.FileTool;
import entity.User;
public class SpliteFileAndMakeScoreTable {
public static void main(String[] args) throws Exception {
/*
String inputDir = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/";
String outputDir = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/out/user/";
String userPath = inputDir + "train_als_itemCF1.csv";
String outputPath = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/out/score.csv";
String scoreTablePath = "D:/workspace/java/recommendSystemDemo/data/unicom/allRecommend/out/scoreTable.csv";
String token = ",";
*/
String inputDir = args[0];
String outputDir = args[1];
String userPath = inputDir + args[2];
String outputPath = args[3];
String scoreTablePath = args[4];
String token = args[5];
//FileTool.makeSampleData(userPath, true, outputPath, 10000);
//List<Object> itemList = FileTool.readFileOne(itemPath, true, ",", "item");
//List<User> userList = FileTool.readFileOne(userPath, true, ",", "user", "mapUser"); //sample2
//List<User> userList = FileTool.readFileOne(userPath, true, "::", "user", "mapUser"); //movielens
List<User> userList = FileTool.readFileOne(userPath, true, token, "user", "mapUser");
Set<String> userSet = new HashSet<String>();
Set<String> itemSet = new HashSet<String>();
Map<String, Map<String, List<User>>> userMap = DataProcessByItemCF.mapByUser(userList,userSet,itemSet);
userList.clear();
DataProcessByItemCF.output(userMap, outputDir);
// Build the itemToUser score table
Map<String, Map<String, Double>> scoreTable = DataProcessByItemCF.makeScoreTable(userMap);
DataProcessByItemCF.output(scoreTable, scoreTablePath , userSet, itemSet, ",");
userMap.clear();
FileTool.initWriter1(outputPath);
CalculateSimilarityByItemCF.execute(scoreTable, userSet, itemSet);
FileTool.closeWriter1();
}
}
Item-based collaborative filtering: reduce the mapped files and sort the users:
package script.runByItemCF;
import java.util.Collections;
import java.util.List;
import service.DataProcessByUserCF;
import util.FileTool;
import entity.User;
public class ReduceFileTest {
public static void main(String[] args) throws Exception {
/*
String inputDir = "data/fresh_comp_offline/sample2/out/user/";
String outputDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
*/
/*
String inputDir = "data/ml-1m/out/user/";
String outputDir = "data/ml-1m/out/sorteduser/";
*/
String inputDir = args[0];
String outputDir = args[1];
List<String> pathList = FileTool.traverseFolder(inputDir);
for(String path : pathList){
List<User> userList = FileTool.readFileOne(inputDir+path, false, "\t", "user", "reduceUser");
List<User> list = DataProcessByUserCF.reduceUserByItem(userList);
userList.clear();
FileTool.initWriter1(outputDir + path);
Collections.sort(list);
DataProcessByUserCF.outputUser(list);
FileTool.closeWriter1();
list.clear();
}
}
}
Item-based collaborative filtering: recommend items for each user and generate the prediction list:
package script.runByItemCF;
import java.util.List;
import java.util.Map;
import java.util.Set;
import service.DataProcessByItemCF;
import util.FileTool;
import entity.Score;
public class PredictTest {
public static void main(String[] args) throws Exception {
/*
String inputDir = "data/fresh_comp_offline/sample2/out/";
String outputDir = "data/fresh_comp_offline/sample2/out/";
String inputPath = inputDir + "score.csv";
String outputPath = outputDir + "predict";
String userDir = "data/fresh_comp_offline/sample2/out/sorteduser/";
int topNUser = 5;
double score_threshold = 0;
int topNItem = 5;
double weight_threshold = 0;
*/
/*
String inputDir = "data/ml-1m/out/out_itemCF/";
String outputDir = "data/ml-1m/out/out_itemCF/";
String inputPath = inputDir + "score.csv";
String outputPath = outputDir + "predict";
String userDir = "data/ml-1m/out/out_itemCF/sorteduser/";
int topNUser = 5;
double score_threshold = 0;
int topNItem = 5;
double weight_threshold = 0;
*/
String inputDir = args[0];
String outputDir = args[1];
String inputPath = inputDir + args[2];
String outputPath = outputDir + args[3];
String userDir = args[4];
int topNUser = Integer.parseInt(args[5]);
double score_threshold = Double.parseDouble(args[6]);
int topNItem = Integer.parseInt(args[7]);
double weight_threshold = Double.parseDouble(args[8]);
Map<String, List<Score>> scoreMap = FileTool.loadScoreMap(inputPath, false, ",", "itemCF");
DataProcessByItemCF.sortScoreMap(scoreMap);
List<String> fileNameList = FileTool.traverseFolder(userDir);
// Recommend the top 5 items of each of the user's 5 most similar users
Map<String, Set<String>> predictMap = DataProcessByItemCF.predict(scoreMap, fileNameList, userDir, topNUser, score_threshold, topNItem, weight_threshold);
FileTool.initWriter1(outputPath);
DataProcessByItemCF.outputRecommendList(predictMap);
FileTool.closeWriter1();
scoreMap.clear();
}
}
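The `predict` method above looks up, for each item a user interacted with, that item's most similar items in `scoreMap`. A common alternative, swapped in here and not what the original code does, ranks candidate items by the weighted sum p(u, j) = Σ_i sim(i, j)·r(u, i) over the items i the user interacted with, so that items close to several of the user's items rise to the top. A minimal sketch with hypothetical names and in-memory maps:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: classic item-based CF scoring p(u, j) = sum_i sim(i, j) * r(u, i).
// This is a swapped-in variant, not the original predict() method.
class ItemCfScoreSketch {

    static List<String> recommend(Map<String, Double> userRatings,           // item -> r(u, item)
                                  Map<String, Map<String, Double>> itemSims, // item -> (item -> sim)
                                  int topN) {
        Map<String, Double> scores = new HashMap<>();
        for (Map.Entry<String, Double> rated : userRatings.entrySet()) {
            Map<String, Double> neighbors = itemSims.getOrDefault(rated.getKey(), Map.of());
            for (Map.Entry<String, Double> nb : neighbors.entrySet()) {
                if (nb.getKey().equals(rated.getKey())) {
                    continue; // skip self-similarity
                }
                // accumulate sim(i, j) * r(u, i) for candidate j
                scores.merge(nb.getKey(), nb.getValue() * rated.getValue(), Double::sum);
            }
        }
        List<Map.Entry<String, Double>> ranked = new ArrayList<>(scores.entrySet());
        ranked.sort(Map.Entry.<String, Double>comparingByValue(Comparator.reverseOrder()));
        List<String> result = new ArrayList<>();
        for (int k = 0; k < Math.min(topN, ranked.size()); k++) {
            result.add(ranked.get(k).getKey());
        }
        return result;
    }
}
```

Compared with taking a fixed number of neighbors per item, the weighted sum lets a user's strongest behaviors (purchases, add-to-cart) pull their neighbors further up the list.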
Compute precision, recall and F-measure:
package script;
import java.util.List;
import java.util.Map;
import service.DataProcessByUserCF;
import util.FileTool;
public class MatchTest {
public static void main(String[] args) throws Exception {
//String predictPath = "data/ml-1m/out/predict_itemCF2";
//String predictPath = "data/ml-1m/out/predict_userCF";
//String predictPath = "data/ml-1m/out/predict_ALS_15_600_0.388_0.01";
//String predictPath = "data/ml-1m/out/predict_ml_ALS_15_600_0.388_0.01_1.csv";
//String predictPath = "data/unicom/outRecommend/out/out_outRecommedList3.csv";
//String predictPath = "data/unicom/allRecommend/out/out_outRecommedList250_0.4.csv";
//String predictPath = "data/unicom/allRecommend/out/predict_als_randomForest.csv";
String predictPath = "data/unicom/out/test_pred.txt";
//String predictPath = "data/ml-1m/out/predict_ml_userCF_RF.csv";
//String predictPath = "data/ml-1m/out/predict_ml_userCF.csv";
//String predictPath = "data/ml-1m/out/predict_ml_ALS.csv";
//String predictPath = "data/ml-1m/out/predict_ml_all1.csv";
//String predictPath = "data/ml-1m/out/predict_ml_ALS_15_600_0.388_0.01.csv";
//String referencePath = "data/ml-1m/sample/testData.dat";
//String referencePath = "data/unicom/outRecommend/test3_list.csv";
//String referencePath = "data/unicom/allRecommend/test_als.csv";
String referencePath = "data/unicom/out/test_real.txt";
Map<String, List<String>> predictMap = FileTool.loadPredictData(predictPath, false, ",", 50);
int predictN = FileTool.count;
FileTool.count = 0;
Map<String, List<String>> referenceMap = FileTool.loadPredictData(referencePath, false, ",", 10000);
int referenceN = FileTool.count;
DataProcessByUserCF.prediction(predictMap, predictN, referenceMap, referenceN);
}
}
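The competition evaluates on sets of (user_id, item_id) pairs: precision = |PredictionSet ∩ ReferenceSet| / |PredictionSet|, recall = |PredictionSet ∩ ReferenceSet| / |ReferenceSet|, and F1 = 2·P·R / (P + R). `DataProcessByUserCF.prediction` reports the same quantities as percentages. A minimal sketch that computes them directly on sets of "userId,itemId" strings, so duplicate predictions collapse automatically, as the required de-duplicated submission does (the class name is hypothetical):

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical sketch: set-based precision / recall / F1 over "userId,itemId" pairs.
class F1Sketch {

    // Returns {precision, recall, f1} as fractions in [0, 1].
    static double[] evaluate(Set<String> predicted, Set<String> reference) {
        Set<String> hits = new HashSet<>(predicted);
        hits.retainAll(reference); // PredictionSet ∩ ReferenceSet
        double precision = predicted.isEmpty() ? 0.0 : (double) hits.size() / predicted.size();
        double recall = reference.isEmpty() ? 0.0 : (double) hits.size() / reference.size();
        double f1 = (precision + recall) == 0.0 ? 0.0 : 2 * precision * recall / (precision + recall);
        return new double[] {precision, recall, f1};
    }
}
```

Since F1 is the harmonic mean of precision and recall, adding many low-confidence pairs raises recall but can lower F1 overall, which is why thresholds such as `topNItem` and `weight_threshold` matter.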
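Script execution order:
(1). User-based collaborative filtering:
1.script.runByUserCF.SpliteFileAndMakeScoreTable.java //map the input file and generate the user-user scores
2.script.runByUserCF.ReduceFileTest.java //sort the mapped files, mainly ordering the item scores within each user
3.script.runByUserCF.PredictTest.java //predict and generate the user-item prediction list
4.runMakeTestSet.sh //generate the test set
5.runSpliteFile.sh //map the test set file
6.runMatch.sh //match the prediction list against the test set and compute precision and recall
(2). Item-based collaborative filtering:
1.script.runByItemCF.SpliteFileAndMakeScoreTable.java //map the input file and generate the item-item scores
2.script.runByItemCF.ReduceFileTest.java //sort the mapped files, mainly ordering the item scores within each user
3.script.runByItemCF.PredictTest.java //predict and generate the user-item prediction list
4.runMakeTestSet.sh //generate the test set
5.runSpliteFile.sh //map the test set file
6.runMatch.sh //match the prediction list against the test set and compute precision and recall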
That is the core code; the full project source is available at:
https://download.csdn.net/download/u013473512/10141066
https://github.com/Emmitte/recommendSystem
Publisher: 全栈程序员-用户IM. Please credit the source when reposting: https://javaforall.cn/125710.html Original link: https://javaforall.cn