优化版 JAVA 最大熵模型(GIS 训练)

本贴最后更新于 2071 天前,其中的信息可能已经渤澥桑田

网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的结构也涉及较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。

     /**
   * 样本数据集
   */
  List<Instance> instanceList = new ArrayList<Instance>();
  /**
   * 特征列表,来自所有事件的统计结果
   */

// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;

   * 训练模型 * @param maxIt 最大迭代次数
   */public void train(int maxIt,String savePath) throws IOException {
      Map,Double> empiricalE = new HashMap<>(); // 经验期望
    Map,Double> modelE = new HashMap<>(); // 模型期望

    for (Map.Entry,Weight> e:weightMap.entrySet())
      {
          double ratio=(double) e.getValue().getCnt() / instanceList.size();
    empiricalE.put(e.getKey(),ratio);
    }
      Map,Double> lastWeight=new HashMap<>();
   for (int i = 0; i < maxIt; ++i)
      {
          System.out.println("iter:"+i);
    computeModeE(modelE);//计算模型期望
    System.out.println("model finish.updating...");
   for (Map.Entry,Weight> e:weightMap.entrySet())
          {
             //lastWeight[w] = weight[w];
    lastWeight.put(e.getKey(),e.getValue().getWeight());
    String f=e.getKey();
   double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
    weightMap.get(f).addWeight(delta);
    }
          System.out.println("saving iter:"+i);
    learningRate*=0.99;
    learningRate=learningRate<10?10:learningRate;
    saveParam(savePath+"ent_insopt.par"+i);
   if (checkConvergence(lastWeight, weightMap)) break;

    }

  }

  /**
   * 预测类别 * @param fieldList
    * @return
    */
  public Pair, Double>[] predict(Map,Integer> fieldList)
  {
      double[] prob = calProb(fieldList);
    Pair, Double>[] pairResult = new Pair[prob.length];
   for (int i = 0; i < prob.length; ++i)
      {
          pairResult[i] = new Pair, Double>(labels.get(i), prob[i]);
    }

      return pairResult;
  }

  /**
   * 检查是否收敛 * @param w1
    * @param w2
    * @return 是否收敛
   */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2)
  {
      System.out.println("w1 size:"+w1.size());
   boolean flag=true;
   for (Map.Entry,Double> e1:w1.entrySet())
      {
          //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
    if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
    flag=false;
    }
      return flag;
  }

  /**
   * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
   */public void computeModeE(Map,Double> modelE)
  {
      modelE.clear();
   double rate=1.0 / instanceList.size();
   for (int i = 0; i < instanceList.size(); ++i)
      {

          Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
   //计算当前样本X对应所有类别的概率  double[] pro = calProb(fieldMap);
   for (Map.Entry,Integer> e:fieldMap.entrySet())
          {
              String insFeature=e.getKey();/**
     * 训练模型
     * @param maxIt 最大迭代次数
     */
    public void train(int maxIt,String savePath) throws IOException {
        Map<String,Double> empiricalE = new HashMap<>();   // 经验期望
        Map<String,Double> modelE = new HashMap<>();       // 模型期望

        for (Map.Entry<String,Weight> e:weightMap.entrySet())
        {
            double ratio=(double) e.getValue().getCnt() / instanceList.size();
            empiricalE.put(e.getKey(),ratio);
        }
        Map<String,Double> lastWeight=new HashMap<>();
        for (int i = 0; i < maxIt; ++i)
        {
            System.out.println("iter:"+i);
            computeModeE(modelE);//计算模型期望
            System.out.println("model finish.updating...");
            for (Map.Entry<String,Weight> e:weightMap.entrySet())
            {
               //lastWeight[w] = weight[w];
                lastWeight.put(e.getKey(),e.getValue().getWeight());
                String f=e.getKey();
                double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
                weightMap.get(f).addWeight(delta);
            }
            System.out.println("saving iter:"+i);
            learningRate*=0.99;
            learningRate=learningRate<10?10:learningRate;
            saveParam(savePath+"ent_insopt.par"+i);
            if (checkConvergence(lastWeight, weightMap)) break;

        }

    }

    /**
     * 预测类别
     * @param fieldList
     * @return
     */
    public Pair<String, Double>[] predict(Map<String,Integer> fieldList)
    {
        double[] prob = calProb(fieldList);
        Pair<String, Double>[] pairResult = new Pair[prob.length];
        for (int i = 0; i < prob.length; ++i)
        {
            pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]);
        }

        return pairResult;
    }

    /**
     * 检查是否收敛
     * @param w1
     * @param w2
     * @return 是否收敛
     */
    public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2)
    {
        System.out.println("w1 size:"+w1.size());
        boolean flag=true;
        for (Map.Entry<String,Double> e1:w1.entrySet())
        {
            //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
            if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
                flag=false;
        }
        return flag;
    }

    /**
     * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。
     * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
     */
    public void computeModeE(Map<String,Double> modelE)
    {
        modelE.clear();
        double rate=1.0 / instanceList.size();
        for (int i = 0; i < instanceList.size(); ++i)
        {

            Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
             //计算当前样本X对应所有类别的概率
            double[] pro = calProb(fieldMap);
            for (Map.Entry<String,Integer> e:fieldMap.entrySet())
            {
                String insFeature=e.getKey();
                int cnt=e.getValue();
                for (int k = 0; k < labels.size(); k++)
                {
                    String feature=labels.get(k)+":"+insFeature;
                      if (weightMap.containsKey(feature)) {
                        double  delta=pro[k] * rate*cnt;
                        modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
                    }
                }
            }
        }
    }
//    public class Mode implements Runnable
//    {
//        ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>();
//        boolean flag=true;
//        List<Instance> i∂
//        public void addIns(int i)
//        {
//
//        }
//
//        @Override
//        public void run() {
//            while(flag)
//            {
//                int ins=insQueue.poll();
//            }
//        }
//    }
    /**
     * 计算p(y|x),此时的x指的是instance里的field
     * @param fieldList 实例的特征列表
     * @return 该实例属于每个类别的概率
     */
    public double[] calProb(Map<String,Integer> fieldList)
    {
        double[] p = new double[labels.size()];
        double sum = 0;  // 正则化因子,保证概率和为1
        for (int i = 0; i < labels.size(); ++i)
        {
            double weightSum = 0;
            String label=labels.get(i);
            for (String field : fieldList.keySet())
            {
                String feature=label+":"+field;
                 if (weightMap.containsKey(feature)) {
                    weightSum += weightMap.get(feature).getWeight()*fieldList.get(field);
                }
            }
            if(weightSum>15)
            {
                weightSum=15;
            }
            p[i] = Math.exp(weightSum);

            sum += p[i];
        }
        //System.out.println();
        for (int i = 0; i < p.length; ++i)
        {
            p[i] /= sum;
//            if(Double.isNaN(p[i]))
//            {
//                System.out.println(p[i]);
//            }
        }
        return p;
    }

    /**
     * 一个观测实例,包含事件和时间发生的环境
     */
    class Instance implements Serializable
    {
        /**
         * 事件(类别),如Outdoor
         */
        String label;
        /**
         * 事件发生的环境集合,如[Sunny, Happy]
         */
        Map<String,Integer> fieldList = new HashMap<>();

        public Instance(String label, Map<String,Integer>fieldList)
        {
            this.label = label;
            this.fieldList = fieldList;
        }
    }

    /**
     * 特征(二值函数)
     */
    class Weight
    {
        double weight=0;
        int cnt=0;
        public void addWeight(double w)
        {
            weight+=(w);
        }
        public double getWeight() {
            return weight;
        }
        public void addCnt(int c)
        {
            cnt+=c;
        }
        public void setWeight(double weight) {
            this.weight = weight;
        }

        public int getCnt() {
            return cnt;
        }

        public void setCnt(int cnt) {
            this.cnt = cnt;
        }
    }
   int cnt=e.getValue();
   for (int k = 0; k < labels.size(); k++)
              {
                  String feature=labels.get(k)+":"+insFeature;
   if (weightMap.containsKey(feature)) {
                      double delta=pro[k] * rate*cnt;
    modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
    }
              }
          }
      }
  }
  • 最大熵模型
    1 引用
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3167 引用 • 8207 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • Vditor

    Vditor 是一款浏览器端的 Markdown 编辑器,支持所见即所得、即时渲染(类似 Typora)和分屏预览模式。它使用 TypeScript 实现,支持原生 JavaScript、Vue、React 和 Angular。

    310 引用 • 1666 回帖
  • 互联网

    互联网(Internet),又称网际网络,或音译因特网、英特网。互联网始于 1969 年美国的阿帕网,是网络与网络之间所串连成的庞大网络,这些网络以一组通用的协议相连,形成逻辑上的单一巨大国际网络。

    96 引用 • 330 回帖
  • Latke

    Latke 是一款以 JSON 为主的 Java Web 框架。

    70 引用 • 532 回帖 • 711 关注
  • Chrome

    Chrome 又称 Google 浏览器,是一个由谷歌公司开发的网页浏览器。该浏览器是基于其他开源软件所编写,包括 WebKit,目标是提升稳定性、速度和安全性,并创造出简单且有效率的使用者界面。

    60 引用 • 287 回帖
  • 微服务

    微服务架构是一种架构模式,它提倡将单一应用划分成一组小的服务。服务之间互相协调,互相配合,为用户提供最终价值。每个服务运行在独立的进程中。服务于服务之间才用轻量级的通信机制互相沟通。每个服务都围绕着具体业务构建,能够被独立的部署。

    96 引用 • 155 回帖 • 1 关注
  • jQuery

    jQuery 是一套跨浏览器的 JavaScript 库,强化 HTML 与 JavaScript 之间的操作。由 John Resig 在 2006 年 1 月的 BarCamp NYC 上释出第一个版本。全球约有 28% 的网站使用 jQuery,是非常受欢迎的 JavaScript 库。

    63 引用 • 134 回帖 • 745 关注
  • Hibernate

    Hibernate 是一个开放源代码的对象关系映射框架,它对 JDBC 进行了非常轻量级的对象封装,使得 Java 程序员可以随心所欲的使用对象编程思维来操纵数据库。

    39 引用 • 103 回帖 • 682 关注
  • 负能量

    上帝为你关上了一扇门,然后就去睡觉了....努力不一定能成功,但不努力一定很轻松 (° ー °〃)

    85 引用 • 1201 回帖 • 455 关注
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3167 引用 • 8207 回帖
  • HHKB

    HHKB 是富士通的 Happy Hacking 系列电容键盘。电容键盘即无接点静电电容式键盘(Capacitive Keyboard)。

    5 引用 • 74 回帖 • 405 关注
  • uTools

    uTools 是一个极简、插件化、跨平台的现代桌面软件。通过自由选配丰富的插件,打造你得心应手的工具集合。

    5 引用 • 13 回帖
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖
  • Mac

    Mac 是苹果公司自 1984 年起以“Macintosh”开始开发的个人消费型计算机,如:iMac、Mac mini、Macbook Air、Macbook Pro、Macbook、Mac Pro 等计算机。

    164 引用 • 594 回帖 • 1 关注
  • OpenShift

    红帽提供的 PaaS 云,支持多种编程语言,为开发人员提供了更为灵活的框架、存储选择。

    14 引用 • 20 回帖 • 602 关注
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 551 关注
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    7 引用 • 30 回帖 • 454 关注
  • flomo

    flomo 是新一代 「卡片笔记」 ,专注在碎片化时代,促进你的记录,帮你积累更多知识资产。

    3 引用 • 80 回帖 • 1 关注
  • Oracle

    Oracle(甲骨文)公司,全称甲骨文股份有限公司(甲骨文软件系统有限公司),是全球最大的企业级软件公司,总部位于美国加利福尼亚州的红木滩。1989 年正式进入中国市场。2013 年,甲骨文已超越 IBM,成为继 Microsoft 后全球第二大软件公司。

    103 引用 • 126 回帖 • 452 关注
  • TensorFlow

    TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即张量(tensor)。

    20 引用 • 19 回帖 • 1 关注
  • 创造

    你创造的作品可能会帮助到很多人,如果是开源项目的话就更赞了!

    172 引用 • 990 回帖
  • Google

    Google(Google Inc.,NASDAQ:GOOG)是一家美国上市公司(公有股份公司),于 1998 年 9 月 7 日以私有股份公司的形式创立,设计并管理一个互联网搜索引擎。Google 公司的总部称作“Googleplex”,它位于加利福尼亚山景城。Google 目前被公认为是全球规模最大的搜索引擎,它提供了简单易用的免费服务。不作恶(Don't be evil)是谷歌公司的一项非正式的公司口号。

    49 引用 • 192 回帖
  • Postman

    Postman 是一款简单好用的 HTTP API 调试工具。

    4 引用 • 3 回帖 • 2 关注
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 511 关注
  • SendCloud

    SendCloud 由搜狐武汉研发中心孵化的项目,是致力于为开发者提供高质量的触发邮件服务的云端邮件发送平台,为开发者提供便利的 API 接口来调用服务,让邮件准确迅速到达用户收件箱并获得强大的追踪数据。

    2 引用 • 8 回帖 • 438 关注
  • 分享

    有什么新发现就分享给大家吧!

    242 引用 • 1746 回帖 • 2 关注
  • Sandbox

    如果帖子标签含有 Sandbox ,则该帖子会被视为“测试帖”,主要用于测试社区功能,排查 bug 等,该标签下内容不定期进行清理。

    368 引用 • 1212 回帖 • 578 关注
  • 快应用

    快应用 是基于手机硬件平台的新型应用形态;标准是由主流手机厂商组成的快应用联盟联合制定;快应用标准的诞生将在研发接口、能力接入、开发者服务等层面建设标准平台;以平台化的生态模式对个人开发者和企业开发者全品类开放。

    15 引用 • 127 回帖 • 1 关注