Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

为EM 算法Demo中新增坐标系打点,使算法结果更加直观 #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions StatisticalLearning/DataMining_EM/Client.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package DataMining_EM;


/**
* EM��������㷨����������
* EM期望最大化算法场景调用类
* @author lyq
*
*/
public class Client {

public static void main(String[] args){
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";

String filePath = "D:\\Eclipse_Workstation\\DataMining_EM\\src\\input.txt";

EMTool tool = new EMTool(filePath);
tool.readDataFile();
Expand Down
177 changes: 177 additions & 0 deletions StatisticalLearning/DataMining_EM/DrawPoints.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@

import java.awt.BasicStroke;
import java.awt.BorderLayout;
import java.awt.Canvas;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Image;
import java.util.ArrayList;

import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.WindowConstants;

/**
*
* @author 樊俊彬
* @Time 2014-1-1
* @modify mindcont
* @date 2016-3-26
* @since 配合 EM 算法进行绘制坐标并打点
*
*/
public class DrawPoints extends JFrame {

private static final long serialVersionUID = 1L;
private Image iBuffer;

// 框架起点坐标、宽高
private final int FRAME_X = 50;
private final int FRAME_Y = 50;
private final int FRAME_WIDTH = 500;
private final int FRAME_HEIGHT = 500;

// 原点坐标
private final int Origin_X = FRAME_X + 40;
private final int Origin_Y = FRAME_Y + FRAME_HEIGHT - 30;

// X轴、Y轴终点坐标
private final int XAxis_X = FRAME_X + FRAME_WIDTH - 30;
private final int XAxis_Y = Origin_Y;
private final int YAxis_X = Origin_X;
private final int YAxis_Y = FRAME_Y + 30;

//坐标轴间隔
private final int INTERVAL = 20;


// 保存Point对象的X Y 坐标
private int[] Coordinate_X = new int [50];
private int[] Coordinate_Y = new int [50];


public DrawPoints(ArrayList<Point> points) {
super("EM Demo");
this.setDefaultCloseOperation(EXIT_ON_CLOSE);
this.setBounds(300, 100, 600, 650);

// 添加控制到框架北部区
JPanel topPanel = new JPanel();
this.add(topPanel, BorderLayout.NORTH);

// 文本框
topPanel.add(new JLabel("EM Demo", JLabel.CENTER));

//坐标点数据列表中读取 X轴 Y轴的坐标值 分别赋值 给 Coordinate_X Coordinate_Y
for(int i=0;i<points.size();i++){
Point point = points.get(i);

Coordinate_X[i]=point.getX();
Coordinate_Y[i]=point.getY();
}

// 添加画布到中央区
MyCanvas ChartCanvas = new MyCanvas();
this.add(ChartCanvas, BorderLayout.CENTER);
this.setResizable(false);
this.setVisible(true);
setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
}


/**}
* 画布绘制坐标系 并打点
*/
class MyCanvas extends Canvas {

private static final long serialVersionUID = 1L;
public void paint(Graphics g) {
Graphics2D g2D = (Graphics2D) g;

// 画边框
g.setColor(Color.BLACK);
g.draw3DRect(FRAME_X, FRAME_Y, FRAME_WIDTH, FRAME_HEIGHT, true);

// 画坐标轴
g.setColor(Color.BLACK);
g2D.setStroke(new BasicStroke(Float.parseFloat("2.0f")));

// X轴及方向箭头
g.drawLine(Origin_X, Origin_Y, XAxis_X, XAxis_Y);
g.drawLine(XAxis_X, XAxis_Y, XAxis_X - 5, XAxis_Y - 5);
g.drawLine(XAxis_X, XAxis_Y, XAxis_X - 5, XAxis_Y + 5);

// Y轴及方向箭头
g.drawLine(Origin_X, Origin_Y, YAxis_X, YAxis_Y);
g.drawLine(YAxis_X, YAxis_Y, YAxis_X - 5, YAxis_Y + 5);
g.drawLine(YAxis_X, YAxis_Y, YAxis_X + 5, YAxis_Y + 5);

// 画X轴上刻度
g.setColor(Color.BLUE);
g2D.setStroke(new BasicStroke(Float.parseFloat("1.0f")));
for (int i = Origin_X + 15, j = 0; i < XAxis_X; i += INTERVAL, j += 20) {
g.drawString(j + "", i - 20, Origin_Y + 20);

}
g.drawString("X轴", XAxis_X + 5, XAxis_Y + 5);

// 画Y轴上刻度
for (int i = Origin_Y, j = 0; i > YAxis_Y; i -= INTERVAL, j += 20) {
g.drawString(j + "", Origin_X - 20, i + 3);
}
g.drawString("Y轴", YAxis_X - 5, YAxis_Y - 5);

// 画网格线
g.setColor(Color.BLACK);
// 横线
for (int i = Origin_Y - INTERVAL; i > YAxis_Y; i -= INTERVAL) {
g.drawLine(Origin_X, i, Origin_X + 21 * INTERVAL, i);
}
// 竖线
for (int i = Origin_X + INTERVAL; i < XAxis_X; i += INTERVAL) {
g.drawLine(i, Origin_Y, i, Origin_Y - 21 * INTERVAL);

}

//设置画笔颜色为绿色
g.setColor(Color.green);
g2D.setStroke(new BasicStroke(Float.parseFloat("5.0f")));
//画出 簇点
g.drawOval(Origin_X+Coordinate_X[0], Origin_Y -Coordinate_Y[0], 5, 5);
g.drawOval(Origin_X+Coordinate_X[1], Origin_Y -Coordinate_Y[1], 5, 5);

//设置画笔颜色 为红色
g.setColor(Color.red);
g2D.setStroke(new BasicStroke(Float.parseFloat("5.0f")));
//画其余各点
for (int i = 2; i < Coordinate_X.length ;i++) {
g.drawLine(Origin_X+ Coordinate_X[i],
Origin_Y - Coordinate_Y[i],
Origin_X+ Coordinate_X[i],
Origin_Y - Coordinate_Y[i]);
}


}

// 双缓冲技术解决图像显示问题
public void update(Graphics g) {
if (iBuffer == null) {
iBuffer = createImage(this.getSize().width,
this.getSize().height);

}
Graphics gBuffer = iBuffer.getGraphics();
gBuffer.setColor(getBackground());
gBuffer.fillRect(0, 0, this.getSize().width, this.getSize().height);
paint(gBuffer);
gBuffer.dispose();
g.drawImage(iBuffer, 0, 0, this);
}
}


}

82 changes: 44 additions & 38 deletions StatisticalLearning/DataMining_EM/EMTool.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
package DataMining_EM;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
Expand All @@ -8,32 +6,37 @@
import java.util.ArrayList;

/**
* EM��������㷨������
* EM最大期望算法工具类
*
* @author lyq
* @modify mindcont
* @date 2016-3-26
* @since 新增 坐标系 打点
*
*/
public class EMTool {
// ���������ļ���ַ
// 测试数据文件地址
private String dataFilePath;
// �������������
// 测试坐标点数据
private String[][] data;
// ��������������б�
// 测试坐标点数据列表
private ArrayList<Point> pointArray;
// Ŀ��C1��
// 目标C1点
private Point p1;
// Ŀ��C2��
// 目标C2点
private Point p2;

public EMTool(String dataFilePath) {
this.dataFilePath = dataFilePath;
pointArray = new ArrayList<>();
}

/**
* ���ļ��ж�ȡ����
* 从文件中读取数据
*/
public void readDataFile() {


File file = new File(dataFilePath);
ArrayList<String[]> dataArray = new ArrayList<String[]>();

Expand All @@ -53,64 +56,65 @@ public void readDataFile() {
data = new String[dataArray.size()][];
dataArray.toArray(data);

// ��ʼʱĬ��ȡͷ2������Ϊ2��������
// 开始时默认取头2个点作为2个簇中心
p1 = new Point(Integer.parseInt(data[0][0]),
Integer.parseInt(data[0][1]));
p2 = new Point(Integer.parseInt(data[1][0]),
Integer.parseInt(data[1][1]));

Point p;
for (String[] array : data) {
// ������ת��Ϊ��������б��������
// 将数据转换为对象加入列表方便计算
p = new Point(Integer.parseInt(array[0]),
Integer.parseInt(array[1]));
pointArray.add(p);
}

}


/**
* ������������2�������ĵ��������
* 计算坐标点对于2个簇中心点的隶属度
*
* @param p
* �����������
* 待测试坐标点
*/
private void computeMemberShip(Point p) {
// p������һ�������ĵ�ľ���
// p点距离第一个簇中心点的距离
double distance1 = 0;
// p����ڶ������ĵ�ľ���
// p距离第二个中心点的距离
double distance2 = 0;

// ��ŷʽ�������
// 用欧式距离计算
distance1 = Math.pow(p.getX() - p1.getX(), 2)
+ Math.pow(p.getY() - p1.getY(), 2);
distance2 = Math.pow(p.getX() - p2.getX(), 2)
+ Math.pow(p.getY() - p2.getY(), 2);

// �������p1��������ȣ������ɷ��ȹ�ϵ�����뿿��ԽС��������Խ������Ҫ�ô��distance2����ľ�������ʾ
p.setMemberShip1(distance2 / (distance1 + distance2));
// �������p2���������
p.setMemberShip2(distance1 / (distance1 + distance2));
// 计算对于p1点的隶属度,与距离成反比关系,距离靠近越小,隶属度越大,所以要用大的distance2另外的距离来表示
p.setMemberShip1((int) (distance2 / (distance1 + distance2)));
// 计算对于p2点的隶属度
p.setMemberShip2((int) (distance1 / (distance1 + distance2)));
}

/**
* ִ��������󻯲���
* 执行期望最大化步骤
*/
public void exceptMaxStep() {
// �µ��Ż����Ĵ����ĵ�
double p1X = 0;
double p1Y = 0;
double p2X = 0;
double p2Y = 0;
double temp1 = 0;
double temp2 = 0;
// ���ֵ
double errorValue1 = 0;
double errorValue2 = 0;
// �ϴθ��µĴص�����
// 新的优化过的簇中心点
int p1X = 0;
int p1Y = 0;
int p2X = 0;
int p2Y = 0;
int temp1 = 0;
int temp2 = 0;
// 误差值
int errorValue1 = 0;
int errorValue2 = 0;
// 上次更新的簇点坐标
Point lastP1 = null;
Point lastP2 = null;

// ����ʼ�����ʱ�򣬻������ĵ�����ֵ����1��ʱ����Ҫ�ٴε�������
// 当开始计算的时候,或是中心点的误差值超过1的时候都需要再次迭代计算
while (lastP1 == null || errorValue1 > 1.0 || errorValue2 > 1.0) {
for (Point p : pointArray) {
computeMemberShip(p);
Expand All @@ -126,7 +130,7 @@ public void exceptMaxStep() {
lastP1 = new Point(p1.getX(), p1.getY());
lastP2 = new Point(p2.getX(), p2.getY());

// �׹�ʽ�����µĴ����ĵ�����,����󻯴���
// 套公式计算新的簇中心点坐标,最最大化处理
p1.setX(p1X / temp1);
p1.setY(p1Y / temp1);
p2.setX(p2X / temp2);
Expand All @@ -139,8 +143,10 @@ public void exceptMaxStep() {
}

System.out.println(MessageFormat.format(
"�����Ľڵ�p1({0}, {1}), p2({2}, {3})", p1.getX(), p1.getY(),
"簇中心节点p1({0}, {1}), p2({2}, {3})", p1.getX(), p1.getY(),
p2.getX(), p2.getY()));

new DrawPoints(pointArray);//调用DrawPoints类 绘制坐标系并打点
}

}
Loading